separated programs and runtimes neatly
The boundary between programs (sampling algorithms) and runtimes
(executors) was not very clear. I removed any dependence of the program
on the model, so responsibilities are now clearly separated. I also
improved the performance of initial-state creation.
rlouf committed Mar 30, 2020
commit cd51114e0b5057f01eacdf8942ee4f9ec677e29c
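For orientation, this is the contract the commit establishes, sketched from the diff below. The function bodies and the name `dummy_program` are illustrative, not part of the commit; only the four-function tuple and the `init`/`warmup`/`build_kernel` signatures are taken from the code.

from collections import namedtuple

State = namedtuple("State", ["position", "log_prob", "log_prob_grad"])

def dummy_program():
    """A minimal, model-agnostic sampling program (hypothetical)."""

    def init(position, value_and_grad):
        # Build the initial chain state from a flattened position.
        log_prob, log_prob_grad = value_and_grad(position)
        return State(position, log_prob, log_prob_grad)

    def warmup(initial_state, logpdf, num_warmup_steps):
        # Adapt the algorithm's parameters; nothing to adapt here.
        return {}, initial_state

    def build_kernel(logpdf, parameters):
        # Return a transition kernel; this toy kernel never moves.
        def kernel(rng_key, state):
            return state
        return kernel

    def to_trace(states, unravel_fn):
        # Map flattened positions back to named variables.
        return [unravel_fn(state.position) for state in states]

    return init, warmup, build_kernel, to_trace

The runtime (mcx.sample, mcx.generate) owns everything model-specific: validating the conditioning variables, building and flattening the log-likelihood, and drawing initial positions.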
4 changes: 2 additions & 2 deletions mcx/__init__.py
@@ -1,13 +1,13 @@
 from mcx.model import model
 from mcx.model import sample_forward, seed
 from mcx.execution import sample, generate
-from mcx.hmc import HMC
+from mcx.inference.hmc import HMC

 __all__ = [
     "model",
     "seed",
-    "HMC",
     "sample_forward",
+    "HMC",
     "sample",
     "generate",
 ]
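In user-facing terms, the new runtimes are called as below. This is a hedged sketch: `my_model`, `observed_data`, the conditioning variable name `y`, and the bare `HMC()` call are placeholders; only the `mcx.sample` and `mcx.generate` signatures come from the diff.

import jax
import mcx
from mcx import HMC

rng_key = jax.random.PRNGKey(0)

# Batch runtime: validate, warm up, compile, then sample.
sampler = mcx.sample(
    rng_key,
    my_model,               # an mcx model definition (placeholder)
    HMC(),                  # the program, built with no model knowledge
    num_warmup_steps=1000,
    num_chains=4,
    y=observed_data,        # conditioning variables go through **kwargs
)

# Generator runtime: yields new chain states one step at a time.
for states in mcx.generate(rng_key, my_model, HMC(), y=observed_data):
    pass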
100 changes: 60 additions & 40 deletions mcx/execution.py
@@ -1,5 +1,3 @@
-import time
-
 import jax
 import jax.numpy as np
 from jax.flatten_util import ravel_pytree
@@ -9,52 +7,44 @@
 from mcx.core import compile_to_logpdf


 __all__ = ["sample", "generate"]


 class sample(object):
     def __init__(
-        self, rng_key, model, program, num_warmup=1000, num_chains=4, **kwargs
+        self, rng_key, model, program, num_warmup_steps=1000, num_chains=4, **kwargs
     ):
         """ Initialize the sampling runtime.
         """
         self.program = program
         self.num_chains = num_chains
-        self.num_warmup = num_warmup
         self.rng_key = rng_key

-        print("Initialize the sampler")
-        print("----------------------")
-
         init, warmup, build_kernel, to_trace = self.program

-        check_conditioning_variables(model, **kwargs)
+        print("Initialize the sampler\n")
+
+        validate_conditioning_variables(model, **kwargs)
         loglikelihood = build_loglikelihood(model, **kwargs)

-        print("Find initial positions...", end=" ")
-        start = time.time()
-        positions, unravel_fn = get_initial_position(
+        print("Find initial states...")
+        initial_position, unravel_fn = get_initial_position(
             rng_key, model, num_chains, **kwargs
         )
-        print("Found in {:.2f}s".format(time.time() - start))

-        flat_loglikelihood = _flatten_logpdf(loglikelihood, unravel_fn)
-
-        print("Compute initial states...", end=" ")
-        start = time.time()
-        value_and_grad = jax.value_and_grad(flat_loglikelihood)
-        initial_state = jax.vmap(init, in_axes=(0, None))(positions, value_and_grad)
-        print("Computed in {:.2f}s".format(time.time() - start))
+        loglikelihood = flatten_loglikelihood(loglikelihood, unravel_fn)
+        initial_state = jax.vmap(init, in_axes=(0, None))(
+            initial_position, jax.value_and_grad(loglikelihood)
+        )

-        parameters, state = warmup(initial_state, flat_loglikelihood)
+        print("Warmup the chains...")
+        parameters, state = warmup(initial_state, loglikelihood, num_warmup_steps)

-        print("Compiling the log-likelihood...", end=" ")
-        start = time.time()
-        loglikelihood = jax.jit(flat_loglikelihood)
-        print("Compiled in {:.2f}s".format(time.time() - start))
+        print("Compile the log-likelihood...")
+        loglikelihood = jax.jit(loglikelihood)

-        print("Compiling the inference kernel...", end=" ")
-        start = time.time()
+        print("Build and compile the inference kernel...")
         kernel = build_kernel(loglikelihood, parameters)
         kernel = jax.jit(kernel)
-        print("Compiled in {:.2f}s".format(time.time() - start))

         self.kernel = kernel
         self.state = state
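An aside on the `jax.vmap(init, in_axes=(0, None))` pattern in the hunk above: the first argument is mapped over its leading (chain) axis while the second is broadcast unchanged to every chain. A toy, self-contained illustration (`scale` is made up):

import jax
import jax.numpy as np

def scale(position, step_size):
    # position is mapped over axis 0; step_size is shared by all chains.
    return position * step_size

positions = np.arange(12.0).reshape(4, 3)  # 4 chains, 3 flattened parameters
scaled = jax.vmap(scale, in_axes=(0, None))(positions, 0.1)
print(scaled.shape)  # (4, 3): one scaled position per chain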
@@ -91,16 +81,32 @@ def update_chains(state, rng_key):
         return trace


-def generate(rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
+def generate(rng_key, model, program, num_warmup_steps=1000, num_chains=4, **kwargs):
     """ The generator runtime """

-    initialize, build_kernel, to_trace = runtime
+    init, warmup, build_kernel, to_trace = program

-    loglikelihood, initial_state, parameters, unravel_fn = initialize(
-        rng_key, num_chains, **kwargs
+    print("Initialize the sampler\n")
+
+    validate_conditioning_variables(model, **kwargs)
+    loglikelihood = build_loglikelihood(model, **kwargs)
+
+    print("Find initial states...")
+    initial_position, unravel_fn = get_initial_position(
+        rng_key, model, num_chains, **kwargs
     )
+    loglikelihood = flatten_loglikelihood(loglikelihood, unravel_fn)
+    initial_state = jax.vmap(init, in_axes=(0, None))(
+        initial_position, jax.value_and_grad(loglikelihood)
+    )
+
+    print("Warmup the chains...")
+    parameters, state = warmup(initial_state, loglikelihood, num_warmup_steps)
+
+    print("Compile the log-likelihood...")
+    loglikelihood = jax.jit(loglikelihood)
+
+    print("Build and compile the inference kernel...")
     kernel = build_kernel(loglikelihood, parameters)
     kernel = jax.jit(kernel)

@@ -114,7 +120,7 @@ def generate(rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
         yield new_states


-def check_conditioning_variables(model, **kwargs):
+def validate_conditioning_variables(model, **kwargs):
     """ Check that all variables passed as arguments to the sampler
     are random variables or arguments to the sampler. And conversely
     that all of the model definition's positional arguments are given
@@ -165,17 +171,31 @@ def get_initial_position(rng_key, model, num_chains, **kwargs):
     samples = sample_forward(rng_key, model, num_samples=num_chains, **kwargs)
     initial_positions = dict((var, samples[var]) for var in to_sample_vars)

-    positions = []
-    for i in range(num_chains):
-        position = {k: value[i] for k, value in initial_positions.items()}
-        flat_position, unravel_fn = ravel_pytree(position)
-        positions.append(flat_position)
-    positions = np.stack(positions)
+    # A naive way to flatten the positions would be to transform the
+    # dictionary of arrays that contains the parameter values into a list
+    # of dictionaries, one per chain, and then unravel each dictionary.
+    # However, this approach takes more time than drawing the samples in
+    # the first place.
+    #
+    # Luckily, JAX sorts dictionaries by key
+    # (https://github.com/google/jax/blob/master/jaxlib/pytree.cc) when
+    # raveling pytrees. We can thus ravel the parameter values and stack
+    # them in key order to get the flattened positions. We then build a
+    # single dictionary with one value per parameter and pass it to
+    # `ravel_pytree` to recover the unraveling function.
+    positions = np.stack(
+        [np.ravel(samples[s]) for s in sorted(initial_positions.keys())], axis=1
+    )
+
+    sample_position_dict = {
+        parameter: values[0] for parameter, values in initial_positions.items()
+    }
+    _, unravel_fn = ravel_pytree(sample_position_dict)

     return positions, unravel_fn


-def _flatten_logpdf(logpdf, unravel_fn):
+def flatten_loglikelihood(logpdf, unravel_fn):
     def flattened_logpdf(array):
         kwargs = unravel_fn(array)
         return logpdf(**kwargs)
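The new comment in `get_initial_position` leans on JAX sorting dictionary keys when raveling pytrees. A standalone check of that behaviour (not part of the diff):

import jax.numpy as np
from jax.flatten_util import ravel_pytree

# Keys are flattened in sorted order, regardless of insertion order.
position = {"sigma": np.array(2.0), "mu": np.array(1.0)}
flat, unravel_fn = ravel_pytree(position)
print(flat)              # [1. 2.] -- "mu" before "sigma"
print(unravel_fn(flat))  # {'mu': ..., 'sigma': ...}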
4 changes: 2 additions & 2 deletions mcx/hmc.py → mcx/inference/hmc.py
@@ -20,7 +20,7 @@ def HMC(
     mass_matrix_sqrt=None,
     inverse_mass_matrix=None,
     integrator=velocity_verlet,
-    is_mass_matrix_diagonal=True,
+    is_mass_matrix_diagonal=False,
 ):

     parameters = HMCParameters(
@@ -31,7 +31,7 @@ def init(position, value_and_grad):
         log_prob, log_prob_grad = value_and_grad(position)
         return HMCState(position, log_prob, log_prob_grad)

-    def warmup(initial_state, logpdf):
+    def warmup(initial_state, logpdf, num_warmup_steps):
         return parameters, initial_state

     def build_kernel(logpdf, parameters):
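For reference, what HMC's `init` computes from the flattened log-likelihood: a single `jax.value_and_grad` call produces both fields stored alongside the position in the state. A standalone toy (`logpdf` stands in for a real model's flattened log-likelihood):

import jax
import jax.numpy as np

def logpdf(position):
    # Stand-in for a model's flattened log-likelihood.
    return -0.5 * np.sum(position ** 2)

position = np.array([1.0, -2.0])
log_prob, log_prob_grad = jax.value_and_grad(logpdf)(position)
print(log_prob)       # -2.5
print(log_prob_grad)  # [-1.  2.]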
5 changes: 2 additions & 3 deletions mcx/model.py
@@ -357,9 +357,8 @@ def sample_forward(rng_key, model: model, num_samples=1000, **kwargs) -> Dict:
     sampler_fn = jax.jit(sampler_fn)
     samples = jax.vmap(sampler_fn, in_axes=in_axes, out_axes=out_axes)(*sampler_args)

-    trace = {}
-    for arg, arg_samples in zip(model.variables, samples):
-        trace[arg] = numpy.asarray(arg_samples).T.squeeze()
+    trace = {arg: numpy.asarray(arg_samples).T.squeeze()
+             for arg, arg_samples in zip(model.variables, samples)}

     return trace


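A standalone illustration of the `.T.squeeze()` step in the comprehension above; the array shape is assumed for the example:

import numpy

arg_samples = numpy.ones((1000, 1))  # 1000 forward draws of a scalar variable
trace_entry = numpy.asarray(arg_samples).T.squeeze()
print(trace_entry.shape)  # (1000,): the singleton dimension is squeezed away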