Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
66bbcbc
Rewrite NDVariableArray
antoine-dedieu Apr 12, 2022
d816f63
Flake8
antoine-dedieu Apr 12, 2022
58b0115
Test + HashableDict
antoine-dedieu Apr 12, 2022
c8f2c5e
Minor
antoine-dedieu Apr 13, 2022
39b546a
Variables as tuple + Remove BPOutputs/HashableDict
antoine-dedieu Apr 20, 2022
ef92d1b
Start tests + mypy
antoine-dedieu Apr 21, 2022
7e55e5c
Tests passing
antoine-dedieu Apr 21, 2022
9a2dcd2
Variables
antoine-dedieu Apr 21, 2022
c6ae8d8
Tests + mypy
antoine-dedieu Apr 21, 2022
5c8b381
Some docstrings
antoine-dedieu Apr 22, 2022
40ec519
Stannis first comments
antoine-dedieu Apr 22, 2022
96c7fe3
Remove add_factor
antoine-dedieu Apr 22, 2022
88f8e23
Test
antoine-dedieu Apr 23, 2022
033d176
Docstring
antoine-dedieu Apr 23, 2022
89767c8
Coverage
antoine-dedieu Apr 25, 2022
1ccfcf5
Coverage
antoine-dedieu Apr 25, 2022
ecaab6c
Coverage 100%
antoine-dedieu Apr 25, 2022
cbd136b
Remove factor group names
antoine-dedieu Apr 25, 2022
22b604e
Remove factor group names
antoine-dedieu Apr 25, 2022
87fcfd9
Modify hash + add_factors
antoine-dedieu Apr 26, 2022
0f639fe
Stannis' comments
antoine-dedieu Apr 26, 2022
546b790
Flatten / unflatten
antoine-dedieu Apr 26, 2022
35ce6c0
Unflatten with nan
antoine-dedieu Apr 26, 2022
704ee3a
Speeding up
antoine-dedieu Apr 27, 2022
aaa67ef
max size
antoine-dedieu Apr 27, 2022
04c4d89
Understand timings
antoine-dedieu Apr 27, 2022
a276ce5
Some comments
antoine-dedieu Apr 27, 2022
a839be6
Comments
antoine-dedieu Apr 27, 2022
05de350
Minor
antoine-dedieu Apr 27, 2022
1e0c93d
Docstring
antoine-dedieu Apr 27, 2022
63c6738
Minor changes
antoine-dedieu Apr 28, 2022
0397f9b
Doc
antoine-dedieu Apr 28, 2022
8b3d60e
Rename this_hash
StannisZhou Apr 30, 2022
8751e90
Final comments
antoine-dedieu May 2, 2022
b265f14
Minor
antoine-dedieu May 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Variables as tuple + Remove BPOutputs/HashableDict
  • Loading branch information
antoine-dedieu committed Apr 20, 2022
commit 39b546a9fa88d8779e8b01ceea2858ec1c71761b
38 changes: 17 additions & 21 deletions examples/ising_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@
variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
fg = graph.FactorGraph(variables=variables)

# TODO: rename variable_for_factors?
variable_names_for_factors = []
variables_for_factors = []
for ii in range(50):
for jj in range(50):
kk = (ii + 1) % 50
ll = (jj + 1) % 50
variable_names_for_factors.append([variables[ii, jj], variables[kk, jj]])
variable_names_for_factors.append([variables[ii, jj], variables[kk, ll]])
variables_for_factors.append([variables[ii, jj], variables[kk, jj]])
variables_for_factors.append([variables[ii, jj], variables[kk, ll]])

fg.add_factor_group(
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=variable_names_for_factors,
variables_for_factors=variables_for_factors,
log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
name="factors",
)

# %% [markdown]
Expand All @@ -52,22 +52,17 @@
# %%
bp = graph.BP(fg.bp_state, temperature=0)

# %%
d = {variables: 1, variables.__hash__(): 2}
hd = graph.HashableDict(d)
hd[variables]

# %%
# TODO: check time before BP vs time for BP
# TODO: time PGMAX vs PMP
bp_arrays = bp.init(
evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
)
bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
output = bp.get_bp_output(bp_arrays)
beliefs = bp.get_beliefs(bp_arrays)

# %%
img = output.map_states[variables]
img = graph.decode_map_states(beliefs)[variables]
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(img)

Expand All @@ -82,19 +77,20 @@ def loss(log_potentials_updates, evidence_updates):
)
bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
beliefs = bp.get_beliefs(bp_arrays)
loss = -jnp.sum(beliefs)
loss = -jnp.sum(beliefs[variables])
return loss


batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {None: 0}), out_axes=0))
batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {variables: 0}), out_axes=0))
log_potentials_grads = jax.jit(jax.grad(loss, argnums=0))

# %%
batch_loss(None, {None: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})
batch_loss(None, {variables: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})

# %%
grads = log_potentials_grads(
{"factors": jnp.eye(2)}, {None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
{"factors": jnp.eye(2)},
{variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))},
)

# %% [markdown]
Expand All @@ -104,15 +100,15 @@ def loss(log_potentials_updates, evidence_updates):
bp_state = bp.to_bp_state(bp_arrays)

# Query evidence for variable (0, 0)
bp_state.evidence[0, 0]
bp_state.evidence[variables[0, 0]]

# %%
# Set evidence for variable (0, 0)
bp_state.evidence[0, 0] = np.array([1.0, 1.0])
bp_state.evidence[0, 0]
bp_state.evidence[variables[0, 0]] = np.array([1.0, 1.0])
bp_state.evidence[variables[0, 0]]

# %%
# Set evidence for all variables using an array
evidence = np.random.randn(50, 50, 2)
bp_state.evidence[None] = evidence
bp_state.evidence[10, 10] == evidence[10, 10]
bp_state.evidence[variables] = evidence
np.allclose(bp_state.evidence[variables[10, 10]], evidence[10, 10])
67 changes: 18 additions & 49 deletions examples/pmp_binary_deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,6 @@ def plot_images(images, display=True, nr=None):
# - a second set of ORFactors, which maps SW to X and models (binary) features overlapping.
#
# See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details.
#
# import imp
#
# import imp

import imp

Expand Down Expand Up @@ -157,8 +153,8 @@ def plot_images(images, display=True, nr=None):
print(time.time() - start)

# Define the ANDFactors
variable_names_for_ANDFactors = []
variable_names_for_ORFactors_dict = defaultdict(list)
variables_for_ANDFactors = []
variables_for_ORFactors_dict = defaultdict(list)
for idx_img in tqdm(range(n_images)):
for idx_chan in range(n_chan):
for idx_s_height in range(s_height):
Expand All @@ -178,36 +174,34 @@ def plot_images(images, display=True, nr=None):
idx_feat_width,
]

variable_names_for_ANDFactor = [
variables_for_ANDFactor = [
S[idx_img, idx_feat, idx_s_height, idx_s_width],
W[idx_chan, idx_feat, idx_feat_height, idx_feat_width],
SW_var,
]
variable_names_for_ANDFactors.append(
variable_names_for_ANDFactor
)
variables_for_ANDFactors.append(variables_for_ANDFactor)

X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width]
variable_names_for_ORFactors_dict[X_var].append(SW_var)
variables_for_ORFactors_dict[X_var].append(SW_var)
print(time.time() - start)

# Add ANDFactorGroup, which is computationally efficient
fg.add_factor_group(
factory=logical.ANDFactorGroup,
variable_names_for_factors=variable_names_for_ANDFactors,
variables_for_factors=variables_for_ANDFactors,
)
print(time.time() - start)

# Define the ORFactors
variable_names_for_ORFactors = [
list(tuple(variable_names_for_ORFactors_dict[X_var]) + (X_var,))
for X_var in variable_names_for_ORFactors_dict
variables_for_ORFactors = [
list(tuple(variables_for_ORFactors_dict[X_var]) + (X_var,))
for X_var in variables_for_ORFactors_dict
]

# Add ORFactorGroup, which is computationally efficient
fg.add_factor_group(
factory=logical.ORFactorGroup,
variable_names_for_factors=variable_names_for_ORFactors,
variables_for_factors=variables_for_ORFactors,
)
print("Time", time.time() - start)

Expand Down Expand Up @@ -236,7 +230,7 @@ def plot_images(images, display=True, nr=None):

# %%
pW = 0.25
pS = 1e-70
pS = 1e-72
pX = 1e-100

# Sparsity inducing priors for W and S
Expand All @@ -250,31 +244,6 @@ def plot_images(images, display=True, nr=None):
uX = np.zeros((X_gt.shape) + (2,))
uX[..., 0] = (2 * X_gt - 1) * logit(pX)

# %%
np.random.seed(seed=40)

start = time.time()
bp_arrays = bp.init(
evidence_updates={
S: uS + np.random.gumbel(size=uS.shape),
W: uW + np.random.gumbel(size=uW.shape),
SW: np.zeros(SW.shape),
X: uX + np.zeros(shape=uX.shape),
},
)
print("Time", time.time() - start)
bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
print("Time", time.time() - start)
output = bp.get_bp_output(bp_arrays)
print("Time", time.time() - start)

# %%
_ = plot_images(output.map_states[W].reshape(-1, feat_height, feat_width), nr=1)

# %%

# %%

# %% [markdown]
# We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`

Expand All @@ -285,10 +254,10 @@ def plot_images(images, display=True, nr=None):
start = time.time()
bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
evidence_updates={
"S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
"W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
"SW": np.zeros(shape=(n_samples,) + SW.shape),
"X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
S: uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
W: uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
SW: np.zeros(shape=(n_samples,) + SW.shape),
X: uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
},
)
print("Time", time.time() - start)
Expand All @@ -298,13 +267,13 @@ def plot_images(images, display=True, nr=None):
out_axes=0,
)(bp_arrays)
print("Time", time.time() - start)
# beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
# map_states = graph.decode_map_states(beliefs)
beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
map_states = graph.decode_map_states(beliefs)

# %% [markdown]
# Visualizing the MAP decoding, we see that we have 4 good random samples (one per row) from the posterior!
#
# Because we have used one extra feature for inference, each posterior sample recovers the 4 basic features used to generate the images, and includes an extra symbol.

# %%
_ = plot_images(map_states["W"].reshape(-1, feat_height, feat_width), nr=n_samples)
_ = plot_images(map_states[W].reshape(-1, feat_height, feat_width), nr=n_samples)
29 changes: 12 additions & 17 deletions examples/rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@

# %% [markdown]
# We can then initialize the factor graph for the RBM with
#
# import imp

import imp

Expand Down Expand Up @@ -116,14 +114,14 @@ def run_numba(h, v, f):
# Add unary factors
fg.add_factor_group(
factory=enumeration.EnumerationFactorGroup,
variable_names_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
factor_configs=np.arange(2)[:, None],
log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
)

fg.add_factor_group(
factory=enumeration.EnumerationFactorGroup,
variable_names_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
factor_configs=np.arange(2)[:, None],
log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
)
Expand All @@ -133,7 +131,7 @@ def run_numba(h, v, f):

fg.add_factor_group(
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=[
variables_for_factors=[
[hidden_variables[ii], visible_variables[jj]]
for ii in range(bh.shape[0])
for jj in range(bv.shape[0])
Expand Down Expand Up @@ -218,7 +216,7 @@ def run_numba(h, v, f):
print("Time", time.time() - start)
bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
print("Time", time.time() - start)
output = bp.get_bp_output(bp_arrays)
beliefs = bp.get_beliefs(bp_arrays)
print("Time", time.time() - start)

# %% [markdown]
Expand All @@ -228,7 +226,9 @@ def run_numba(h, v, f):

# %%
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(output.map_states[visible_variables].copy().reshape((28, 28)), cmap="gray")
# Decode the MAP states from the beliefs with graph.decode_map_states, the
# helper used for MAP decoding in the other examples, then display the
# visible variables as a 28x28 grayscale image.
ax.imshow(
    graph.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)),
    cmap="gray",
)
ax.axis("off")

# %% [markdown]
Expand Down Expand Up @@ -256,8 +256,8 @@ def run_numba(h, v, f):
n_samples = 10
bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
evidence_updates={
"hidden": np.random.gumbel(size=(n_samples, bh.shape[0], 2)),
"visible": np.random.gumbel(size=(n_samples, bv.shape[0], 2)),
hidden_variables: np.random.gumbel(size=(n_samples, bh.shape[0], 2)),
visible_variables: np.random.gumbel(size=(n_samples, bv.shape[0], 2)),
},
)
bp_arrays = jax.vmap(
Expand All @@ -266,11 +266,8 @@ def run_numba(h, v, f):
out_axes=0,
)(bp_arrays)

# TODO: problem
outputs = jax.vmap(bp.get_bp_output, in_axes=0, out_axes=0)(bp_arrays)
# map_states = graph.decode_map_states(beliefs)

# %%
beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
map_states = graph.decode_map_states(beliefs)

# %% [markdown]
# Visualizing the MAP decodings (Figure [fig:rbm_multiple_digits]), we see that we have sampled 10 MNIST digits in parallel!
Expand All @@ -279,10 +276,8 @@ def run_numba(h, v, f):
fig, ax = plt.subplots(2, 5, figsize=(20, 8))
for ii in range(10):
ax[np.unravel_index(ii, (2, 5))].imshow(
map_states["visible"][ii].copy().reshape((28, 28)), cmap="gray"
map_states[visible_variables][ii].copy().reshape((28, 28)), cmap="gray"
)
ax[np.unravel_index(ii, (2, 5))].axis("off")

fig.tight_layout()

# %%
Loading