Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
66bbcbc
Rewrite NDVariableArray
antoine-dedieu Apr 12, 2022
d816f63
Falke8
antoine-dedieu Apr 12, 2022
58b0115
Test + HashableDict
antoine-dedieu Apr 12, 2022
c8f2c5e
Minbor
antoine-dedieu Apr 13, 2022
39b546a
Variables as tuple + Remove BPOuputs/HashableDict
antoine-dedieu Apr 20, 2022
ef92d1b
Start tests + mypy
antoine-dedieu Apr 21, 2022
7e55e5c
Tests passing
antoine-dedieu Apr 21, 2022
9a2dcd2
Variables
antoine-dedieu Apr 21, 2022
c6ae8d8
Tests + mypy
antoine-dedieu Apr 21, 2022
5c8b381
Some docstrings
antoine-dedieu Apr 22, 2022
40ec519
Stannis first comments
antoine-dedieu Apr 22, 2022
96c7fe3
Remove add_factor
antoine-dedieu Apr 22, 2022
88f8e23
Test
antoine-dedieu Apr 23, 2022
033d176
Docstring
antoine-dedieu Apr 23, 2022
89767c8
Coverage
antoine-dedieu Apr 25, 2022
1ccfcf5
Coverage
antoine-dedieu Apr 25, 2022
ecaab6c
Coverage 100%
antoine-dedieu Apr 25, 2022
cbd136b
Remove factor group names
antoine-dedieu Apr 25, 2022
22b604e
Remove factor group names
antoine-dedieu Apr 25, 2022
87fcfd9
Modify hash + add_factors
antoine-dedieu Apr 26, 2022
0f639fe
Stannis' comments
antoine-dedieu Apr 26, 2022
546b790
Flattent / unflattent
antoine-dedieu Apr 26, 2022
35ce6c0
Unflatten with nan
antoine-dedieu Apr 26, 2022
704ee3a
Speeding up
antoine-dedieu Apr 27, 2022
aaa67ef
max size
antoine-dedieu Apr 27, 2022
04c4d89
Understand timings
antoine-dedieu Apr 27, 2022
a276ce5
Some comments
antoine-dedieu Apr 27, 2022
a839be6
Comments
antoine-dedieu Apr 27, 2022
05de350
Minor
antoine-dedieu Apr 27, 2022
1e0c93d
Docstring
antoine-dedieu Apr 27, 2022
63c6738
Minor changes
antoine-dedieu Apr 28, 2022
0397f9b
Doc
antoine-dedieu Apr 28, 2022
8b3d60e
Rename this_hash
StannisZhou Apr 30, 2022
8751e90
Final comments
antoine-dedieu May 2, 2022
b265f14
Minor
antoine-dedieu May 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Rewrite NDVariableArray
  • Loading branch information
antoine-dedieu committed Apr 12, 2022
commit 66bbcbcd47d83d34278866ce2deae870838c7f39
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ repos:
args: ['--config', '.flake8.config','--exit-zero']
verbose: true

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.942' # Use the sha / tag you want to point at
hooks:
- id: mypy
additional_dependencies: [tokenize-rt==3.2.0]
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: 'v0.942' # Use the sha / tag you want to point at
# hooks:
# - id: mypy
# additional_dependencies: [tokenize-rt==3.2.0]
ci:
autoupdate_schedule: 'quarterly'
25 changes: 20 additions & 5 deletions examples/ising_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,50 @@
# %%
variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
fg = graph.FactorGraph(variables=variables)

# TODO: rename variable_for_factors?
variable_names_for_factors = []
for ii in range(50):
for jj in range(50):
kk = (ii + 1) % 50
ll = (jj + 1) % 50
variable_names_for_factors.append([(ii, jj), (kk, jj)])
variable_names_for_factors.append([(ii, jj), (ii, ll)])
variable_names_for_factors.append([variables[ii, jj], variables[kk, jj]])
variable_names_for_factors.append([variables[ii, jj], variables[kk, ll]])

fg.add_factor_group(
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=variable_names_for_factors,
log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
name="factors",
)

# %%
from pgmax.factors import enumeration as enumeration_factor

factors = fg.factor_groups[enumeration_factor.EnumerationFactor][0].factors

# %% [markdown]
# ### Run inference and visualize results

import imp

# %%
from pgmax.fg import graph

imp.reload(graph)
bp = graph.BP(fg.bp_state, temperature=0)

# %%
# TODO: check\ time for before BP vs time for BP
# TODO: why bug when done twice?
# TODO: time PGMAX vs PMP
bp_arrays = bp.init(
evidence_updates={None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
)
bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
output = bp.get_bp_output(bp_arrays)

# %%
img = graph.decode_map_states(bp.get_beliefs(bp_arrays))
img = output.map_states[variables]
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(img)

Expand Down
61 changes: 47 additions & 14 deletions examples/pmp_binary_deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,13 @@ def plot_images(images, display=True, nr=None):
#
# See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details.

import imp

# %%
from pgmax.fg import graph

imp.reload(graph)

# The dimensions of W used for the generation of X were (4, 5, 5) but we set them to (5, 6, 6)
# to simulate a more realistic scenario in which we do not know their ground truth values
n_feat, feat_height, feat_width = 5, 6, 6
Expand All @@ -116,6 +122,9 @@ def plot_images(images, display=True, nr=None):
s_height = im_height - feat_height + 1
s_width = im_width - feat_width + 1

import time

start = time.time()
# Binary features
W = vgroup.NDVariableArray(
num_states=2, shape=(n_chan, n_feat, feat_height, feat_width)
Expand All @@ -132,13 +141,16 @@ def plot_images(images, display=True, nr=None):

# Binary images obtained by convolution
X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape)
print("Time", time.time() - start)

# %% [markdown]
# For computation efficiency, we add large FactorGroups via `fg.add_factor_group` instead of adding individual Factors

# %%
start = time.time()
# Factor graph
fg = graph.FactorGraph(variables=dict(S=S, W=W, SW=SW, X=X))
fg = graph.FactorGraph(variables=[S, W, SW, X])
print("x", time.time() - start)

# Define the ANDFactors
variable_names_for_ANDFactors = []
Expand All @@ -152,44 +164,39 @@ def plot_images(images, display=True, nr=None):
for idx_feat_width in range(feat_width):
idx_img_height = idx_feat_height + idx_s_height
idx_img_width = idx_feat_width + idx_s_width
SW_var = (
"SW",
SW_var = SW[
idx_img,
idx_chan,
idx_img_height,
idx_img_width,
idx_feat,
idx_feat_height,
idx_feat_width,
)
]

variable_names_for_ANDFactor = [
("S", idx_img, idx_feat, idx_s_height, idx_s_width),
(
"W",
idx_chan,
idx_feat,
idx_feat_height,
idx_feat_width,
),
S[idx_img, idx_feat, idx_s_height, idx_s_width],
W[idx_chan, idx_feat, idx_feat_height, idx_feat_width],
SW_var,
]
variable_names_for_ANDFactors.append(
variable_names_for_ANDFactor
)

X_var = (idx_img, idx_chan, idx_img_height, idx_img_width)
X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width]
variable_names_for_ORFactors_dict[X_var].append(SW_var)
print(time.time() - start)

# Add ANDFactorGroup, which is computationally efficient
fg.add_factor_group(
factory=logical.ANDFactorGroup,
variable_names_for_factors=variable_names_for_ANDFactors,
)
print(time.time() - start)

# Define the ORFactors
variable_names_for_ORFactors = [
list(tuple(variable_names_for_ORFactors_dict[X_var]) + (("X",) + X_var,))
list(tuple(variable_names_for_ORFactors_dict[X_var]) + (X_var,))
for X_var in variable_names_for_ORFactors_dict
]

Expand All @@ -198,6 +205,7 @@ def plot_images(images, display=True, nr=None):
factory=logical.ORFactorGroup,
variable_names_for_factors=variable_names_for_ORFactors,
)
print("Time", time.time() - start)

for factor_type, factor_groups in fg.factor_groups.items():
if len(factor_groups) > 0:
Expand All @@ -215,7 +223,9 @@ def plot_images(images, display=True, nr=None):
# in the same manner does not change X, so this naturally results in multiple equivalent modes.

# %%
start = time.time()
bp = graph.BP(fg.bp_state, temperature=0.0)
print("Time", time.time() - start)

# %% [markdown]
# We first compute the evidence without perturbation, similar to the PMP paper.
Expand All @@ -236,6 +246,29 @@ def plot_images(images, display=True, nr=None):
uX = np.zeros((X_gt.shape) + (2,))
uX[..., 0] = (2 * X_gt - 1) * logit(pX)

# %%
np.random.seed(seed=40)
n_samples = 1

start = time.time()
bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
evidence_updates={
"S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
"W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
"SW": np.zeros(shape=(n_samples,) + SW.shape),
"X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
},
)
print("Time", time.time() - start)
bp_arrays = jax.vmap(
functools.partial(bp.run_bp, num_iters=100, damping=0.5),
in_axes=0,
out_axes=0,
)(bp_arrays)
print("Time", time.time() - start)
# beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
# map_states = graph.decode_map_states(beliefs)

# %% [markdown]
# We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`

Expand Down
47 changes: 34 additions & 13 deletions examples/rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,31 +45,40 @@
# %% [markdown]
# We can then initialize the factor graph for the RBM with

import imp

# %%
from pgmax.fg import graph

imp.reload(graph)

import time

start = time.time()
# Initialize factor graph
hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape)
visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape)
fg = graph.FactorGraph(
variables=dict(hidden=hidden_variables, visible=visible_variables),
)
fg = graph.FactorGraph(variables=[hidden_variables, visible_variables])
print("Time", time.time() - start)

# %% [markdown]
# [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
#
# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using

# %%
start = time.time()
# Add unary factors
fg.add_factor_group(
factory=enumeration.EnumerationFactorGroup,
variable_names_for_factors=[[("hidden", ii)] for ii in range(bh.shape[0])],
variable_names_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
factor_configs=np.arange(2)[:, None],
log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
)

fg.add_factor_group(
factory=enumeration.EnumerationFactorGroup,
variable_names_for_factors=[[("visible", jj)] for jj in range(bv.shape[0])],
variable_names_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
factor_configs=np.arange(2)[:, None],
log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
)
Expand All @@ -81,13 +90,16 @@
fg.add_factor_group(
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=[
[("hidden", ii), ("visible", jj)]
[hidden_variables[ii], visible_variables[jj]]
for ii in range(bh.shape[0])
for jj in range(bv.shape[0])
],
log_potential_matrix=log_potential_matrix,
)

# fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=[[hidden_variables[ii], visible_variables[jj]]for ii in range(bh.shape[0])for jj in range(bv.shape[0])], log_potential_matrix=log_potential_matrix,)
print("Time", time.time() - start)


# %% [markdown]
# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing Groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup), two [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s implemented in the [`pgmax.fg.groups`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module.
Expand Down Expand Up @@ -146,18 +158,23 @@
# Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model

# %%
start = time.time()
bp = graph.BP(fg.bp_state, temperature=0.0)
print("Time", time.time() - start)

# %%
start = time.time()
bp_arrays = bp.init(
evidence_updates={
"hidden": np.random.gumbel(size=(bh.shape[0], 2)),
"visible": np.random.gumbel(size=(bv.shape[0], 2)),
hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)),
visible_variables: np.random.gumbel(size=(bv.shape[0], 2)),
},
)
print("Time", time.time() - start)
bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
beliefs = bp.get_beliefs(bp_arrays)
map_states = graph.decode_map_states(beliefs)
print("Time", time.time() - start)
output = bp.get_bp_output(bp_arrays)
print("Time", time.time() - start)

# %% [markdown]
# Here we use the `evidence_updates` argument of `run_bp` to perturb the model with Gumbel unary potentials. In general, `evidence_updates` can be used to incorporate evidence in the form of externally applied unary potentials in PGM inference.
Expand All @@ -166,7 +183,7 @@

# %%
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(map_states["visible"].copy().reshape((28, 28)), cmap="gray")
ax.imshow(output.map_states[visible_variables].copy().reshape((28, 28)), cmap="gray")
ax.axis("off")

# %% [markdown]
Expand Down Expand Up @@ -203,8 +220,12 @@
in_axes=0,
out_axes=0,
)(bp_arrays)
beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
map_states = graph.decode_map_states(beliefs)

outputs = jax.vmap(bp.get_bp_output, in_axes=0, out_axes=0)(bp_arrays)
# map_states = graph.decode_map_states(beliefs)

# %%
O

# %% [markdown]
# Visualizing the MAP decodings (Figure [fig:rbm_multiple_digits]), we see that we have sampled 10 MNIST digits in parallel!
Expand Down
20 changes: 10 additions & 10 deletions pgmax/factors/enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def __post_init__(self):
f"EnumerationFactor. Got a factor_configs array of shape {self.factor_configs.shape}."
)

if len(self.variables) != self.factor_configs.shape[1]:
if len(self.vars_to_num_states.keys()) != self.factor_configs.shape[1]:
raise ValueError(
f"Number of variables {len(self.variables)} doesn't match given configurations {self.factor_configs.shape}"
f"Number of variables {len(self.vars_to_num_states.keys())} doesn't match given configurations {self.factor_configs.shape}"
)

if self.log_potentials.shape != (self.factor_configs.shape[0],):
Expand All @@ -93,7 +93,7 @@ def __post_init__(self):
f"shape {self.log_potentials.shape}."
)

vars_num_states = np.array([variable.num_states for variable in self.variables])
vars_num_states = np.array([list(self.vars_to_num_states.values())])
if not np.logical_and(
self.factor_configs >= 0, self.factor_configs < vars_num_states[None]
).all():
Expand Down Expand Up @@ -155,9 +155,9 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
@staticmethod
def compile_wiring(
factor_edges_num_states: np.ndarray,
variables_for_factors: Tuple[nodes.Variable, ...],
variables_for_factors: Tuple[int, ...], # TODO: rename
factor_configs: np.ndarray,
vars_to_starts: Mapping[nodes.Variable, int],
vars_to_starts: Mapping[int, int],
num_factors: int,
) -> EnumerationWiring:
"""Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
Expand All @@ -166,7 +166,7 @@ def compile_wiring(
Args:
factor_edges_num_states: An array concatenating the number of states for the variables connected to each
Factor of the FactorGroup. Each variable will appear once for each Factor it connects to.
variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup.
variables_for_factors: A tuple containing variables connected to each Factor of the FactorGroup.
Each variable will appear once for each Factor it connects to.
factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration
of all valid configurations.
Expand All @@ -184,10 +184,10 @@ def compile_wiring(
var_states = np.array(
[vars_to_starts[variable] for variable in variables_for_factors]
)
num_states = np.array(
[variable.num_states for variable in variables_for_factors]
)
num_states_cumsum = np.insert(np.cumsum(num_states), 0, 0)
# num_states = np.array(
# [variable.num_states for variable in variables_for_factors]
# )
num_states_cumsum = np.insert(np.cumsum(factor_edges_num_states), 0, 0)
var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int)
_compile_var_states_numba(var_states_for_edges, num_states_cumsum, var_states)

Expand Down
Loading