diff --git a/mcx/__init__.py b/mcx/__init__.py
index e94d6f41..15b9c9de 100644
--- a/mcx/__init__.py
+++ b/mcx/__init__.py
@@ -1,13 +1,13 @@
 from mcx.model import model
 from mcx.model import sample_forward, seed
 from mcx.execution import sample, generate
-from mcx.hmc import HMC
+from mcx.inference.hmc import HMC
 
 __all__ = [
     "model",
     "seed",
-    "HMC",
     "sample_forward",
+    "HMC",
     "sample",
     "generate",
 ]
diff --git a/mcx/distributions/distribution.py b/mcx/distributions/distribution.py
index c874cf3d..cb034c4d 100644
--- a/mcx/distributions/distribution.py
+++ b/mcx/distributions/distribution.py
@@ -5,6 +5,7 @@
 from jax import numpy as np
 
 from .constraints import Constraint
+from .utils import broadcast_batch_shape
 
 
 class Distribution(ABC):
@@ -82,6 +83,14 @@ def sample(
         """
         pass
 
+    def broadcast_to(self, destination_shape):
+        """Broadcast the distribution to a destination shape.
+
+        This is used, for instance, in neural network layers where the
+        distribution must be broadcast to match the layer's size.
+        """
+        self.batch_shape = broadcast_batch_shape(self.batch_shape, destination_shape)
+
     def forward(
         self,
         rng_key: jax.random.PRNGKey,
diff --git a/mcx/execution.py b/mcx/execution.py
index b5f4de39..c1031473 100644
--- a/mcx/execution.py
+++ b/mcx/execution.py
@@ -1,53 +1,112 @@
 import jax
+import jax.numpy as np
+from jax.flatten_util import ravel_pytree
+from tqdm import tqdm
+
+from mcx import sample_forward
+from mcx.core import compile_to_logpdf
+
+
+__all__ = ["sample", "generate"]
 
-class sample(object):
-    def __init__(self, rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
+
+class sample(object):
+    def __init__(
+        self, rng_key, model, program, num_warmup_steps=1000, num_chains=4, **kwargs
+    ):
         """ Initialize the sampling runtime.
         """
-        self.runtime = runtime
+        self.program = program
         self.num_chains = num_chains
-        self.num_warmup = num_warmup
        self.rng_key = rng_key
 
-        initialize, build_kernel, to_trace = self.runtime
-        loglikelihood, initial_state, parameters, unravel_fn = initialize(rng_key, self.num_chains, **kwargs)
+        init, warmup, build_kernel, to_trace = self.program
+
+        print("Initialize the sampler\n")
+
+        validate_conditioning_variables(model, **kwargs)
+        loglikelihood = build_loglikelihood(model, **kwargs)
+
+        print("Find initial states...")
+        initial_position, unravel_fn = get_initial_position(
+            rng_key, model, num_chains, **kwargs
+        )
+        loglikelihood = flatten_loglikelihood(loglikelihood, unravel_fn)
+        initial_state = jax.vmap(init, in_axes=(0, None))(
+            initial_position, jax.value_and_grad(loglikelihood)
+        )
+
+        print("Warmup the chains...")
+        parameters, state = warmup(initial_state, loglikelihood, num_warmup_steps)
+
+        print("Compile the log-likelihood...")
         loglikelihood = jax.jit(loglikelihood)
 
+        print("Build and compile the inference kernel...")
         kernel = build_kernel(loglikelihood, parameters)
-        self.kernel = jax.jit(kernel)
+        kernel = jax.jit(kernel)
 
-        self.state = initial_state
-        self.unravel_fn = unravel_fn
+        self.kernel = kernel
+        self.state = state
         self.to_trace = to_trace
+        self.unravel_fn = unravel_fn
 
-    def take(self, num_samples=1000):
+    def run(self, num_samples=1000):
         _, self.rng_key = jax.random.split(self.rng_key)
 
         @jax.jit
-        def update_chains(states, rng_key):
+        def update_chains(state, rng_key):
             keys = jax.random.split(rng_key, self.num_chains)
-            new_states, info = jax.vmap(self.kernel, in_axes=(0, 0))(keys, states)
-            return new_states, states
+            new_states, info = jax.vmap(self.kernel, in_axes=(0, 0))(keys, state)
+            return new_states
 
-        rng_keys = jax.random.split(self.rng_key, num_samples)
-        last_state, states = jax.lax.scan(update_chains, self.state, rng_keys)
-        self.state = last_state
+        state = self.state
+        chain = []
 
-        trace = self.to_trace(states, self.unravel_fn)
+        rng_keys = jax.random.split(self.rng_key, num_samples)
+        with tqdm(rng_keys, unit="samples") as progress:
+            progress.set_description(
+                "Collecting {:,} samples across {:,} chains".format(
+                    num_samples, self.num_chains
+                ),
+                refresh=False,
+            )
+            for key in progress:
+                state = update_chains(state, key)
+                chain.append(state)
+        self.state = state
+
+        trace = self.to_trace(chain, self.unravel_fn)
 
         return trace
 
 
-def generate(rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
+def generate(rng_key, model, program, num_warmup_steps=1000, num_chains=4, **kwargs):
     """ The generator runtime """
-    initialize, build_kernel, to_trace = runtime
+    init, warmup, build_kernel, to_trace = program
+
+    print("Initialize the sampler\n")
+
+    validate_conditioning_variables(model, **kwargs)
+    loglikelihood = build_loglikelihood(model, **kwargs)
 
-    loglikelihood, initial_state, parameters, unravel_fn = initialize(rng_key, num_chains, **kwargs)
+    print("Find initial states...")
+    initial_position, unravel_fn = get_initial_position(
+        rng_key, model, num_chains, **kwargs
+    )
+    loglikelihood = flatten_loglikelihood(loglikelihood, unravel_fn)
+    initial_state = jax.vmap(init, in_axes=(0, None))(
+        initial_position, jax.value_and_grad(loglikelihood)
+    )
+
+    print("Warmup the chains...")
+    parameters, state = warmup(initial_state, loglikelihood, num_warmup_steps)
+
+    print("Compile the log-likelihood...")
     loglikelihood = jax.jit(loglikelihood)
 
+    print("Build and compile the inference kernel...")
     kernel = build_kernel(loglikelihood, parameters)
     kernel = jax.jit(kernel)
@@ -59,3 +118,86 @@ def generate(rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
 
         new_states = jax.vmap(kernel)(keys, state)
         yield new_states
+
+
+def validate_conditioning_variables(model, **kwargs):
+    """ Check that all variables passed as arguments to the sampler
+    are either random variables or arguments to the model definition,
+    and conversely that all of the model definition's positional
+    arguments are given a value.
+    """
+    conditioning_vars = set(kwargs.keys())
+    model_randvars = set(model.random_variables)
+    model_args = set(model.arguments)
+    available_vars = model_randvars.union(model_args)
+
+    # The variables passed as an argument to the initialization (variables
+    # on which the logpdf is conditioned) must be either a random variable
+    # or an argument to the model definition.
+    if not available_vars.issuperset(conditioning_vars):
+        unknown_vars = list(conditioning_vars.difference(available_vars))
+        unknown_str = ", ".join(unknown_vars)
+        raise AttributeError(
+            "You passed a value for {}, which are neither random variables nor arguments to the model definition.".format(
+                unknown_str
+            )
+        )
+
+    # The user must provide a value for all of the model definition's
+    # positional arguments.
+    model_posargs = set(model.posargs)
+    if model_posargs.difference(conditioning_vars):
+        missing_vars = model_posargs.difference(conditioning_vars)
+        missing_str = ", ".join(missing_vars)
+        raise AttributeError(
+            "You need to specify a value for the following arguments: {}".format(
+                missing_str
+            )
+        )
+
+
+def build_loglikelihood(model, **kwargs):
+    artifact = compile_to_logpdf(model.graph, model.namespace)
+    logpdf = artifact.compiled_fn
+    loglikelihood = jax.partial(logpdf, **kwargs)
+    return loglikelihood
+
+
+def get_initial_position(rng_key, model, num_chains, **kwargs):
+    conditioning_vars = set(kwargs.keys())
+    model_randvars = set(model.random_variables)
+    to_sample_vars = model_randvars.difference(conditioning_vars)
+
+    samples = sample_forward(rng_key, model, num_samples=num_chains, **kwargs)
+    initial_positions = dict((var, samples[var]) for var in to_sample_vars)
+
+    # A naive way to go about flattening the positions is to transform the
+    # dictionary of arrays that contains the parameter values to a list of
+    # dictionaries, one per position, and then unravel the dictionaries.
+    # However, this approach takes more time than getting the samples in the
+    # first place.
+    #
+    # Luckily, JAX sorts dictionaries by keys
+    # (https://github.com/google/jax/blob/master/jaxlib/pytree.cc) when
+    # raveling pytrees. We can thus ravel and stack the parameter values in an
+    # array, sorted by key; this gives our flattened positions. We then build
+    # a single dictionary that contains the parameter values and use it to get
+    # the unraveling function from `ravel_pytree`.
+    positions = np.stack(
+        [np.ravel(samples[s]) for s in sorted(initial_positions.keys())], axis=1
+    )
+
+    sample_position_dict = {
+        parameter: values[0] for parameter, values in initial_positions.items()
+    }
+    _, unravel_fn = ravel_pytree(sample_position_dict)
+
+    return positions, unravel_fn
+
+
+def flatten_loglikelihood(logpdf, unravel_fn):
+    def flattened_logpdf(array):
+        kwargs = unravel_fn(array)
+        return logpdf(**kwargs)
+
+    return flattened_logpdf
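The comment in `get_initial_position` above leans on `jax.flatten_util.ravel_pytree` flattening dictionaries in sorted key order. A small illustrative sketch of that behaviour (the variable names are made up, not taken from the patch):

>>> from jax import numpy as np
>>> from jax.flatten_util import ravel_pytree
>>> position = {"sigma": np.array(1.0), "coef": np.zeros(2)}
>>> flat_position, unravel_fn = ravel_pytree(position)
>>> flat_position.shape  # "coef" entries come first, then "sigma": keys are sorted
(3,)
>>> unravel_fn(flat_position)["coef"].shape  # the inverse map restores the dict
(2,)

This is the property `flatten_loglikelihood` depends on: the flat array handed to the kernel can always be mapped back to the named variables the model's logpdf expects.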
- """ - - conditioning_vars = set(kwargs.keys()) - model_randvars = set(model.random_variables) - model_args = set(model.arguments) - available_vars = model_randvars.union(model_args) - - # The variables passed as an argument to the initialization (variables - # on which the logpdf is conditionned) must be either a random variable - # or an argument to the model definition. - if not available_vars.issuperset(conditioning_vars): - unknown_vars = list(conditioning_vars.difference(available_vars)) - unknown_str = ", ".join(unknown_vars) - raise AttributeError("You passed a value for {} which are neither random variables nor arguments to the model definition.".format(unknown_str)) - - # The user must provide a value for all of the model definition's - # positional arguments. - model_posargs = set(model.posargs) - if model_posargs.difference(conditioning_vars): - missing_vars = (model_posargs.difference(conditioning_vars)) - missing_str = ", ".join(missing_vars) - raise AttributeError("You need to specify a value for the following arguments: {}".format(missing_str)) - - # Condition on data to obtain the model's log-likelihood - loglikelihood = jax.partial(logpdf, **kwargs) - - # Sample one initial position per chain from the prior - to_sample_vars = model_randvars.difference(conditioning_vars) - samples = sample_forward(rng_key, model, num_samples=num_chains, **kwargs) - initial_positions = dict((var, samples[var]) for var in to_sample_vars) - - positions = [] - for i in range(num_chains): - position = {k: value[i] for k, value in initial_positions.items()} - flat_position, unravel_fn = ravel_pytree(position) - positions.append(flat_position) - positions = np.stack(positions) - - # Transform the likelihood to use a flat array as single argument - flat_loglikelihood = _flatten_logpdf(loglikelihood, unravel_fn) - - # Compute the log probability and gradient to define initial state - logprobs, logprob_grads = jax.vmap(jax.value_and_grad(flat_loglikelihood))(positions) - initial_state = HMCState(positions, logprobs, logprob_grads) - - return flat_loglikelihood, initial_state, parameters, unravel_fn - - def build_kernel(logpdf, parameters): - """Builds the kernel that moves the chain from one point - to the next. - """ - - try: - mass_matrix_sqrt = parameters.mass_matrix_sqrt - inverse_mass_matrix = parameters.inverse_mass_matrix - step_size = parameters.step_size - except AttributeError: - AttributeError( - "The Hamiltonian Monte Carlo algorithm requires the following parameters: mass matrix, inverse mass matrix and step size." - ) - - momentum_generator, kinetic_energy = gaussian_euclidean_metric( - mass_matrix_sqrt, inverse_mass_matrix, - ) - integrator = velocity_verlet(logpdf, kinetic_energy) - proposal = hmc_proposal(integrator, step_size, path_length) - kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy, logpdf) - - return kernel - - def to_trace(states_chain, ravel_fn): - """ Translate the raw chains to a format that can be understood by and - is useful to humans. - """ - return states_chain - - return initialize, build_kernel, to_trace diff --git a/mcx/inference/adaptive.py b/mcx/inference/adaptive.py index f863e43f..e3d06953 100644 --- a/mcx/inference/adaptive.py +++ b/mcx/inference/adaptive.py @@ -8,11 +8,6 @@ The Stan Manual [1]_ is a very good reference on automatic tuning of parameters used in Hamiltonian Monte Carlo. -.. note: - This is a "flat zone": values used to update the step size or the mass - matrix are 1D arrays. 
-    level.
-
 .. [1]: "HMC Algorithm Parameters", Stan Manual
         https://mc-stan.org/docs/2_20/reference-manual/hmc-algorithm-parameters.html
 """
diff --git a/mcx/inference/hmc.py b/mcx/inference/hmc.py
new file mode 100644
index 00000000..fd809d9a
--- /dev/null
+++ b/mcx/inference/hmc.py
@@ -0,0 +1,67 @@
+from typing import NamedTuple
+
+from jax import numpy as np
+
+from mcx.inference.integrators import velocity_verlet, hmc_proposal
+from mcx.inference.kernels import HMCState, hmc_kernel
+from mcx.inference.metrics import gaussian_euclidean_metric
+
+
+class HMCParameters(NamedTuple):
+    step_size: float
+    num_integration_steps: int
+    mass_matrix_sqrt: np.DeviceArray
+    inverse_mass_matrix: np.DeviceArray
+
+
+def HMC(
+    step_size=None,
+    num_integration_steps=None,
+    mass_matrix_sqrt=None,
+    inverse_mass_matrix=None,
+    integrator=velocity_verlet,
+    is_mass_matrix_diagonal=False,
+):
+
+    parameters = HMCParameters(
+        step_size, num_integration_steps, mass_matrix_sqrt, inverse_mass_matrix
+    )
+
+    def init(position, value_and_grad):
+        log_prob, log_prob_grad = value_and_grad(position)
+        return HMCState(position, log_prob, log_prob_grad)
+
+    def warmup(initial_state, logpdf, num_warmup_steps):
+        return parameters, initial_state
+
+    def build_kernel(logpdf, parameters):
+        """Builds the kernel that moves the chain from one point
+        to the next.
+        """
+
+        try:
+            mass_matrix_sqrt = parameters.mass_matrix_sqrt
+            inverse_mass_matrix = parameters.inverse_mass_matrix
+            num_integration_steps = parameters.num_integration_steps
+            step_size = parameters.step_size
+        except AttributeError:
+            raise AttributeError(
+                "The Hamiltonian Monte Carlo algorithm requires the following parameters: mass matrix, inverse mass matrix and step size."
+            )
+
+        momentum_generator, kinetic_energy = gaussian_euclidean_metric(
+            mass_matrix_sqrt, inverse_mass_matrix,
+        )
+        integrator_step = integrator(logpdf, kinetic_energy)
+        proposal = hmc_proposal(integrator_step, step_size, num_integration_steps)
+        kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy, logpdf)
+
+        return kernel
+
+    def to_trace(states_chain, ravel_fn):
+        """Translate the raw chains into a format that humans can
+        understand and use for analysis.
+        """
+        return states_chain
+
+    return init, warmup, build_kernel, to_trace
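To make the new program interface concrete, here is a hypothetical end-to-end sketch of how the `(init, warmup, build_kernel, to_trace)` tuple returned by `HMC` plugs into `mcx.sample`. The toy model, the data arrays `x_data` and `y_data`, the parameter values and the `Normal` import path are assumptions made for illustration; only `mcx.model`, `HMC`, `sample` and `run` come from this patch. Note that `warmup` is currently a pass-through, so the HMC parameters have to be provided up front:

>>> import jax
>>> import jax.numpy as np
>>> import mcx
>>> from mcx.distributions import Normal
>>>
>>> @mcx.model
... def linear(x):
...     coef <~ Normal(0, 1)
...     y <~ Normal(coef * x, 1.0)
...     return y
>>>
>>> program = mcx.HMC(
...     step_size=1e-3,
...     num_integration_steps=10,
...     mass_matrix_sqrt=np.array([1.0]),
...     inverse_mass_matrix=np.array([1.0]),
... )
>>> rng_key = jax.random.PRNGKey(0)
>>> # x_data and y_data are hypothetical observations, not defined in this patch
>>> sampler = mcx.sample(rng_key, linear, program, x=x_data, y=y_data)
>>> trace = sampler.run(num_samples=1000)

The `generate` runtime accepts the same arguments and yields one batch of chain states per iteration, so it can be consumed lazily instead of calling `run`.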
""" - num_steps = np.clip(path_length / step_size, a_min=1).astype("int32") - @jax.jit def propose(_, init_state: Proposal) -> Proposal: new_state = jax.lax.fori_loop( - 0, num_steps, lambda i, state: integrator(state, step_size), init_state + 0, num_integration_steps, lambda i, state: integrator(state, step_size), init_state ) return new_state diff --git a/mcx/inference/kernels.py b/mcx/inference/kernels.py index b34a1d93..1624bb64 100644 --- a/mcx/inference/kernels.py +++ b/mcx/inference/kernels.py @@ -1,8 +1,4 @@ """Sampling kernels. - -.. note: - This file is a "flat zone": positions and logprobs are 1 dimensional - arrays. Raveling and unraveling logic must happen outside. """ from typing import Callable, NamedTuple, Tuple diff --git a/mcx/inference/metrics.py b/mcx/inference/metrics.py index 9b6115af..9dc51cf1 100644 --- a/mcx/inference/metrics.py +++ b/mcx/inference/metrics.py @@ -1,8 +1,4 @@ """Generate dynamics on a Euclidean Manifold or Riemannian Manifold. - -.. note: - This file is a "flat zone": positions and logprobs are 1 dimensional - arrays. Raveling and unraveling logic must happen outside. """ from typing import Callable, Tuple diff --git a/mcx/inference/runtime.py b/mcx/inference/runtime.py deleted file mode 100644 index 4b8f8602..00000000 --- a/mcx/inference/runtime.py +++ /dev/null @@ -1,20 +0,0 @@ -class Runtime(object): - def __init__(self, model, state): - self.state = state - self.model = model - self.rng_key = model.rng_key - - def logpdf_fn(self): - raise NotImplementedError - - def warmup(self, initial_states, num_iterations): - raise NotImplementedError - - def initialize(self, logpdf, num_warmup, num_samples, **kwargs): - raise NotImplementedError - - def inference_kernel(self, logpdf, warmup_state): - raise NotImplementedError - - def to_trace(self, states): - raise NotImplementedError diff --git a/mcx/inference/warmup.py b/mcx/inference/warmup.py index 7f7d7750..decccf21 100644 --- a/mcx/inference/warmup.py +++ b/mcx/inference/warmup.py @@ -1,7 +1,4 @@ """Warming up the chain. - -.. note: - This is a "flat zone": all positions are 1D array. """ from functools import partial from typing import Callable, Tuple diff --git a/mcx/layers/base.py b/mcx/layers/base.py new file mode 100644 index 00000000..0009914d --- /dev/null +++ b/mcx/layers/base.py @@ -0,0 +1,86 @@ +import mcx.distributions as md +import trax.layer as tl + + +class Layer(tl.Layer, md.Distribution): + """Base class for composable layers in a Bayesian deep model. + + Random layers are the building block of Bayesian deep models. We implement + sublayers by subclassing trax's layers. We denote here by 'random layers' + layers whose weights are drawn from a distribution. To define a random + layer, one must use the `@` symbol to indicate a random variable. + + Random layers are distributions on the space of functions, and as such must + behave like a Distribution: each layers must be associated with a + log-probability density function and a function that returns samples from + the layer's distribution. + + Here are the functionalities we want to implement before a first release: + + - [ ] ml.Serial construct that takes care of `logpdf` and `sample` + - [ ] dense layer + - [ ] softmax + - [ ] ReLu + - [ ] Deterministic transformations + - [ ] Handle the case where defining `Normal(0, 1)` in the NN definition and + outside in the example below return the same thing. 
+
+    Examples
+    --------
+
+    We can define a simple neural network to classify
+    MNIST images as:
+
+    >>> @mcx.model
+    ... def mnist(image):
+    ...     nn <~ ml.Serial(
+    ...         dense(400, Normal(0, 1)),
+    ...         dense(400, Normal(0, 1)),
+    ...         dense(10, Normal(0, 1)),
+    ...         softmax(),
+    ...     )
+    ...     probs = nn(image)
+    ...     cat <~ Categorical(probs)
+    ...     return cat
+
+    One can easily add a hierarchical structure on top of the layers:
+
+    >>> @mcx.model
+    ... def mnist(image):
+    ...     s <~ Poisson(.5)
+    ...     nn <~ ml.Serial(
+    ...         dense(400, Normal(0, s)),
+    ...         dense(400, Normal(0, s)),
+    ...         dense(10, Normal(0, s)),
+    ...         softmax(),
+    ...     )
+    ...     probs = nn(image)
+    ...     cat <~ Categorical(probs)
+    ...     return cat
+
+    In case we want a more complex probabilistic structure on the weights,
+    we can pass weight transformations to the layers:
+
+    >>> @mcx.model
+    ... def mnist(image):
+    ...     b1 <~ Bernoulli(.5)
+    ...     nn <~ ml.Serial(
+    ...         dense(400, Normal(0, 1), fn=lambda w: w*b1),
+    ...         dense(400, Normal(0, 1)),
+    ...         dense(10, Normal(0, 1)),
+    ...         softmax(),
+    ...     )
+    ...     probs = nn(image)
+    ...     cat <~ Categorical(probs)
+    ...     return cat
+
+    """
+
+    def logpdf(self, x):
+        """Returns the logpdf associated with the layer. Defaults to 0 for
+        deterministic layers.
+        """
+        return 0
+
+    def sample(self):
+        return self.forward(self.weights)
diff --git a/mcx/layers/combinators.py b/mcx/layers/combinators.py
new file mode 100644
index 00000000..058313c1
--- /dev/null
+++ b/mcx/layers/combinators.py
@@ -0,0 +1,103 @@
+import mcx.distributions as md
+
+try:
+    import trax.layers as tl
+except ImportError:
+    raise ImportError(
+        "You need to install trax (`pip install trax`) to use the neural network functionalities."
+    )
+
+
+class Serial(tl.Serial, md.Distribution):
+    """Combinator that applies layers sequentially.
+
+    >>> def mnist(image):
+    ...     nn <~ ml.Serial(
+    ...         dense(400, Normal(0, 1)),
+    ...         dense(400, Normal(0, 1)),
+    ...         dense(10, Normal(0, 1)),
+    ...         softmax(),
+    ...     )
+    ...     p = nn(image)
+    ...     cat <~ Categorical(p)
+    ...     return cat
+
+    Computing the logpdf:
+
+    >>> def mnist_logpdf(image, w1, w2, w3, cat):
+    ...     nn_logpdf = nn.logpdf(w1, w2, w3)
+    ...     # set the layer weights to w1, w2, w3
+    ...     p = nn(image)
+    ...     cat_logpdf = Categorical(p).logpdf(cat)
+    ...     return nn_logpdf + cat_logpdf
+
+    Sampling from the distribution:
+
+    >>> def mnist_sample(rng_key, image, shape=()):
+    ...     nn = ml.Serial(
+    ...         dense(400, Normal(0, 1)),
+    ...         dense(400, Normal(0, 1)),
+    ...         dense(10, Normal(0, 1)),
+    ...         softmax(),
+    ...     ).sample(rng_key, shape)
+    ...     p = nn(image)
+    ...     cat = Categorical(p).sample(rng_key, shape)
+    ...     return cat
+
+    """
+
+    def __init__(self, *sublayers):
+        super(tl.Serial, self).__init__()
+
+    def logpdf(self, *weights):
+        """Should compute the logpdf of the weights as well
+        as set the weights to the passed values.
+
+        Example
+        -------
+
+        The following model:
+
+        >>> @mcx.model
+        ... def prior():
+        ...     nn <~ ml.Serial(
+        ...         dense(10, Normal(0, 1)),
+        ...         relu(),
+        ...         dense(20, Normal(0, 1)),
+        ...         softmax(),
+        ...     )
+        ...     return nn
+
+        should be associated with the following logpdf function:
+
+        >>> def prior_logpdf(w1, w2):
+        ...     logpdf = Normal(0, 1).logpdf(w1)
+        ...     logpdf += Normal(0, 1).logpdf(w2)
+        ...     return logpdf
+
+        which is obtained by summing the log-probability associated with
+        each layer.
+        """
+        logpdf = 0
+        for layer, w in zip(self.sublayers, weights):
+            logpdf += layer.logpdf(w)
+        return logpdf
+
+    def sample(self, rng_key, sample_shape):
+        """Should make sure that the forward pass generates samples
+        every time a new element is passed as an input.
+
+        >>> nn.sample(rng_key, sample_shape)
+
+        Should return a sample of the `nn` function; this sets the value
+        of the weights for the next forward pass.
+
+        -> Overloads forward()? i.e. returns a function.
+           In which case, to be consistent, we should
+           implement forward() for our model as getting a sample of
+           the generated variable.
+        """
+        self.rng_key = rng_key
+        self.sample_shape = sample_shape
+        self.forward_with_state = self.sample_with_state
+        return self
diff --git a/mcx/layers/core.py b/mcx/layers/core.py
new file mode 100644
index 00000000..7137c6d7
--- /dev/null
+++ b/mcx/layers/core.py
@@ -0,0 +1,42 @@
+from typing import Callable, Optional
+
+import jax.numpy as np
+
+import mcx.distributions as md
+import trax.layers as tl
+
+
+class Dense(tl.Dense, md.Distribution):
+    """Dense layer, which computes
+
+    .. math:: X W + b
+    """
+
+    def __init__(
+        self,
+        n_units: int,
+        distribution: Optional[md.Distribution] = None,
+        transform: Optional[Callable] = None,
+        kernel_initializer=tl.init.GlorotUniformInitializer(),
+        bias_initializer=tl.init.RandomNormalInitializer(1e-6),
+    ):
+        super(tl.Dense, self).__init__(n_units, kernel_initializer, bias_initializer)
+        self.distribution = distribution
+        self.transform = transform if transform else lambda x: x
+
+    def forward(self, x, weights):
+        w, b = weights
+        w_tilde = self.transform(w)
+        return np.dot(x, w_tilde) + b
+
+    def sample(self, x, weights, rng_key, sample_shape):
+        """Used during forward sampling."""
+        w, b = weights
+        w_tilde = self.distribution.sample(
+            rng_key, sample_shape
+        )  # this may be problematic
+        return np.dot(x, w_tilde) + b
+
+    def logpdf(self, weights):
+        w, b = weights
+        return self.distribution.logpdf(w)
diff --git a/mcx/model.py b/mcx/model.py
index 7f61fed4..a46d1f54 100644
--- a/mcx/model.py
+++ b/mcx/model.py
@@ -357,9 +357,8 @@ def sample_forward(rng_key, model: model, num_samples=1000, **kwargs) -> Dict:
 
     sampler_fn = jax.jit(sampler_fn)
     samples = jax.vmap(sampler_fn, in_axes=in_axes, out_axes=out_axes)(*sampler_args)
 
-    trace = {}
-    for arg, arg_samples in zip(model.variables, samples):
-        trace[arg] = numpy.asarray(arg_samples).T.squeeze()
+    trace = {arg: numpy.asarray(arg_samples).T.squeeze() for arg, arg_samples in zip(model.variables, samples)}
+
     return trace
 
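As a closing illustration of the `sample_forward` change: the returned trace is a plain dictionary keyed by the model's variable names. A hypothetical call against the toy `linear` model sketched earlier (the data array `x_data` is again made up) would look roughly like:

>>> samples = mcx.sample_forward(rng_key, linear, num_samples=100, x=x_data)
>>> sorted(samples.keys())        # one entry per model variable
['coef', 'y']
>>> samples["coef"].shape         # squeezed to (num_samples,) for scalar variables
(100,)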