Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
66bbcbc
Rewrite NDVariableArray
antoine-dedieu Apr 12, 2022
d816f63
Falke8
antoine-dedieu Apr 12, 2022
58b0115
Test + HashableDict
antoine-dedieu Apr 12, 2022
c8f2c5e
Minbor
antoine-dedieu Apr 13, 2022
39b546a
Variables as tuple + Remove BPOuputs/HashableDict
antoine-dedieu Apr 20, 2022
ef92d1b
Start tests + mypy
antoine-dedieu Apr 21, 2022
7e55e5c
Tests passing
antoine-dedieu Apr 21, 2022
9a2dcd2
Variables
antoine-dedieu Apr 21, 2022
c6ae8d8
Tests + mypy
antoine-dedieu Apr 21, 2022
5c8b381
Some docstrings
antoine-dedieu Apr 22, 2022
40ec519
Stannis first comments
antoine-dedieu Apr 22, 2022
96c7fe3
Remove add_factor
antoine-dedieu Apr 22, 2022
88f8e23
Test
antoine-dedieu Apr 23, 2022
033d176
Docstring
antoine-dedieu Apr 23, 2022
89767c8
Coverage
antoine-dedieu Apr 25, 2022
1ccfcf5
Coverage
antoine-dedieu Apr 25, 2022
ecaab6c
Coverage 100%
antoine-dedieu Apr 25, 2022
cbd136b
Remove factor group names
antoine-dedieu Apr 25, 2022
22b604e
Remove factor group names
antoine-dedieu Apr 25, 2022
87fcfd9
Modify hash + add_factors
antoine-dedieu Apr 26, 2022
0f639fe
Stannis' comments
antoine-dedieu Apr 26, 2022
546b790
Flattent / unflattent
antoine-dedieu Apr 26, 2022
35ce6c0
Unflatten with nan
antoine-dedieu Apr 26, 2022
704ee3a
Speeding up
antoine-dedieu Apr 27, 2022
aaa67ef
max size
antoine-dedieu Apr 27, 2022
04c4d89
Understand timings
antoine-dedieu Apr 27, 2022
a276ce5
Some comments
antoine-dedieu Apr 27, 2022
a839be6
Comments
antoine-dedieu Apr 27, 2022
05de350
Minor
antoine-dedieu Apr 27, 2022
1e0c93d
Docstring
antoine-dedieu Apr 27, 2022
63c6738
Minor changes
antoine-dedieu Apr 28, 2022
0397f9b
Doc
antoine-dedieu Apr 28, 2022
8b3d60e
Rename this_hash
StannisZhou Apr 30, 2022
8751e90
Final comments
antoine-dedieu May 2, 2022
b265f14
Minor
antoine-dedieu May 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Stannis' comments
  • Loading branch information
antoine-dedieu committed Apr 26, 2022
commit 0f639fe6fcce3422c970e9aeea64a31a100b558c
8 changes: 0 additions & 8 deletions examples/gmrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,3 @@ def update(step, batch_noisy_images, batch_target_images, opt_state):
)
pbar.update()
pbar.set_postfix(loss=value)

batch_indices = indices[idx * batch_size : (idx + 1) * batch_size]
batch_noisy_images, batch_target_images = (
noisy_images_train[:10],
target_images_train[:10],
)
step = 0
value, opt_state = update(step, batch_noisy_images, batch_target_images, opt_state)
4 changes: 2 additions & 2 deletions pgmax/factors/enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import functools
from dataclasses import dataclass
from typing import List, Mapping, Sequence, Tuple, Union
from typing import Any, List, Mapping, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -156,7 +156,7 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
def compile_wiring(
variables_for_factors: Sequence[List],
factor_configs: np.ndarray,
vars_to_starts: Mapping[Tuple[int, int], int],
vars_to_starts: Mapping[Tuple[Any, int], int],
num_factors: int,
) -> EnumerationWiring:
"""Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
Expand Down
4 changes: 2 additions & 2 deletions pgmax/factors/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import functools
from dataclasses import dataclass, field
from typing import List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -141,7 +141,7 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
@staticmethod
def compile_wiring(
variables_for_factors: Sequence[List],
vars_to_starts: Mapping[Tuple[int, int], int],
vars_to_starts: Mapping[Tuple[Any, int], int],
edge_states_offset: int,
) -> LogicalWiring:
"""Compile a LogicalWiring for a LogicalFactor or a FactorGroup with LogicalFactors.
Expand Down
66 changes: 17 additions & 49 deletions pgmax/fg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import copy
import functools
import inspect
import typing
from dataclasses import asdict, dataclass
from types import MappingProxyType
from typing import (
Expand Down Expand Up @@ -44,12 +43,7 @@ class FactorGraph:
Factors in a graph are clustered in factor groups, which are grouped according to their factor types.

Args:
variables: A single VariableGroup or a container containing variable groups.
If not a single VariableGroup, supported containers include mapping and sequence.
For a mapping, the keys of the mapping are used to index the variable groups.
For a sequence, the indices of the sequence are used to index the variable groups.
Note that if not a single VariableGroup, a CompositeVariableGroup will be created from
this input, and the individual VariableGroups will need to be accessed by indexing.
variable_groups: A single VariableGroup or a list of VariableGroups.
"""

variable_groups: Union[groups.VariableGroup, Sequence[groups.VariableGroup]]
Expand Down Expand Up @@ -103,7 +97,7 @@ def add_factors(

Args:
factor_group: The FactorGroup to be added to the FactorGraph.
factor: The Factor to be added to the factor graph.
factor: The Factor to be added to the FactorGraph.

Raises:
ValueError: If
Expand Down Expand Up @@ -318,13 +312,13 @@ class FactorGraphState:
"""FactorGraphState.

Args:
variable_groups: All the variable groups in the FactorGraph.
variable_groups: VariableGroups in the FactorGraph.
vars_to_starts: Maps variables to their starting indices in the flat evidence array.
flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states]
contains evidence to the variable.
num_var_states: Total number of variable states.
total_factor_num_states: Size of the flat ftov messages array.
factor_groups: Factor groups in the FactorGraph
factor_groups: FactorGroups in the FactorGraph
factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages.
factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials.
factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials.
Expand All @@ -333,7 +327,7 @@ class FactorGraphState:
"""

variable_groups: Sequence[groups.VariableGroup]
vars_to_starts: Mapping[Tuple[int, int], int]
vars_to_starts: Mapping[Tuple[Any, int], int]
num_var_states: int
total_factor_num_states: int
factor_groups: Tuple[groups.FactorGroup, ...]
Expand Down Expand Up @@ -474,12 +468,11 @@ def __setitem__(
factor_group: Any,
data: Union[np.ndarray, jnp.ndarray],
):
"""Set the log potentials for a named factor group or a factor.
"""Set the log potentials for a FactorGroup

Args:
factor_group: FactorGroup
data: Array containing the log potentials for the named factor group
or the factor.
data: Array containing the log potentials for the FactorGroup
"""
object.__setattr__(
self,
Expand Down Expand Up @@ -510,7 +503,7 @@ def update_ftov_msgs(

Raises: ValueError if:
(1) provided ftov_msgs shape does not match the expected ftov_msgs shape.
(2) provided name is not valid for ftov_msgs updates.
(2) provided variable is not in the FactorGraph.
"""
for variable, data in updates.items():
if variable in fg_state.vars_to_starts:
Expand All @@ -535,14 +528,7 @@ def update_ftov_msgs(
data / starts.shape[0]
)
else:
raise ValueError(
"Invalid names for setting messages. "
"Supported names include a tuple of length 2 with factor "
"and variable names for directly setting factor to variable "
"messages, or a valid variable name for spreading expected "
"beliefs at a variable"
)

raise ValueError("Provided variable is not in the FactorGraph")
return ftov_msgs


Expand Down Expand Up @@ -574,38 +560,19 @@ def __post_init__(self):

object.__setattr__(self, "value", self.value)

@typing.overload
def __setitem__(
self,
names: Tuple[Any, Any],
data: Union[np.ndarray, jnp.ndarray],
) -> None:
"""Setting messages from a factor to a variable

Args:
names: A tuple of length 2
names[0] is the name of the factor
names[1] is the name of the variable
data: An array containing messages from factor names[0]
to variable names[1]
"""

@typing.overload
def __setitem__(
self,
variable: Tuple[int, int],
variable: Tuple[Any, int],
data: Union[np.ndarray, jnp.ndarray],
) -> None:
"""Spreading beliefs at a variable to all connected factors
"""Spreading beliefs at a variable to all connected Factors

Args:
variable: A tuple representing a variable
data: An array containing the beliefs to be spread uniformly
across all factor to variable messages involving this
variable.
across all factors to variable messages involving this variable.
"""

def __setitem__(self, variable, data) -> None:
object.__setattr__(
self,
"value",
Expand Down Expand Up @@ -647,7 +614,7 @@ def update_evidence(
evidence = evidence.at[start_index : start_index + name[1]].set(data)
else:
raise ValueError(
"Got evidence for a variable or a variable group not in the FactorGraph!"
"Got evidence for a variable or a VariableGroup not in the FactorGraph!"
)
return evidence

Expand Down Expand Up @@ -678,7 +645,7 @@ def __post_init__(self):

object.__setattr__(self, "value", self.value)

def __getitem__(self, variable: Tuple[int, int]) -> np.ndarray:
def __getitem__(self, variable: Tuple[Any, int]) -> np.ndarray:
"""Function to query evidence for a variable

Args:
Expand Down Expand Up @@ -969,8 +936,9 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
beliefs = {}
start = 0
for variable_group in variable_groups:
variables = variable_group.variables
length = sum([variable[1] for variable in variables])
num_states = variable_group.num_states
assert isinstance(num_states, np.ndarray)
length = num_states.sum()

beliefs[variable_group] = variable_group.unflatten(
flat_beliefs[start : start + length]
Expand Down
9 changes: 5 additions & 4 deletions pgmax/fg/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import total_ordering
from typing import (
Any,
Collection,
FrozenSet,
List,
Mapping,
Expand All @@ -21,7 +22,7 @@
import pgmax.fg.nodes as nodes
from pgmax.utils import cached_property

MAX_SIZE = 1e16
MAX_SIZE = 1e10


@total_ordering
Expand All @@ -40,7 +41,7 @@ def __eq__(self, other):
def __lt__(self, other):
return hash(self) < hash(other)

def __getitem__(self, val):
def __getitem__(self, val: Any) -> Union[Tuple[Any, int], List[Tuple[Any, int]]]:
"""Given a variable name, index, or a group of variable indices, retrieve the associated variable(s).
Each variable is returned via a tuple of the form (variable hash/name, number of states)

Expand Down Expand Up @@ -129,7 +130,7 @@ def __eq__(self, other):
def __lt__(self, other):
return hash(self) < hash(other)

def __getitem__(self, variables: Sequence[Tuple[int, int]]) -> Any:
def __getitem__(self, variables: Union[Sequence, Collection]) -> Any:
"""Function to query individual factors in the factor group

Args:
Expand Down Expand Up @@ -228,7 +229,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
"Please subclass the FactorGroup class and override this method"
)

def compile_wiring(self, vars_to_starts: Mapping[Tuple[int, int], int]) -> Any:
def compile_wiring(self, vars_to_starts: Mapping[Tuple[Any, int], int]) -> Any:
"""Compile an efficient wiring for the FactorGroup.

Args:
Expand Down
4 changes: 2 additions & 2 deletions pgmax/fg/nodes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A module containing classes that specify the basic components of a Factor Graph."""

from dataclasses import asdict, dataclass
from typing import List, Sequence, Tuple, Union
from typing import Any, List, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -48,7 +48,7 @@ class Factor:
NotImplementedError: If compile_wiring is not implemented
"""

variables: List[Tuple[int, int]]
variables: List[Tuple[Any, int]]
log_potentials: np.ndarray

def __post_init__(self):
Expand Down
7 changes: 3 additions & 4 deletions pgmax/groups/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses import dataclass, field
from typing import FrozenSet, OrderedDict, Type

import numpy as np

from pgmax.factors import logical
from pgmax.fg import groups

Expand All @@ -20,12 +22,9 @@ class LogicalFactorGroup(groups.FactorGroup):
For ORFactors the edge_states_offset is 1, for ANDFactors the edge_states_offset is -1.
"""

factor_configs: np.ndarray = field(init=False, default=None)
edge_states_offset: int = field(init=False)

def __post_init__(self):
super().__post_init__()
object.__setattr__(self, "factor_configs", None)

def _get_variables_to_factors(
self,
) -> OrderedDict[FrozenSet, logical.LogicalFactor]:
Expand Down
22 changes: 12 additions & 10 deletions pgmax/groups/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class NDVariableArray(groups.VariableGroup):
num_states: Union[int, np.ndarray]

def __post_init__(self):
if np.prod(self.shape) > groups.MAX_SIZE:
if np.prod(self.shape) > int(groups.MAX_SIZE):
raise ValueError(
f"Currently only support NDVariableArray of size smaller than {groups.MAX_SIZE}. Got {np.prod(self.shape)}"
f"Currently only support NDVariableArray of size smaller than {int(groups.MAX_SIZE)}. Got {np.prod(self.shape)}"
)

if np.isscalar(self.num_states):
Expand All @@ -42,7 +42,9 @@ def __post_init__(self):
f"Expected num_states shape {self.shape}. Got {self.num_states.shape}."
)
else:
raise ValueError("num_states entries should be of type np.int")
raise ValueError(
"num_states should be an integer or a NumPy array of dtype int"
)

def __getitem__(
self, val: Union[int, slice, Tuple]
Expand Down Expand Up @@ -94,7 +96,7 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:

Args:
data: Meaningful structured data. Should be an array of shape self.shape (for e.g. MAP decodings)
or self.shape + (self.num_states,) (for e.g. evidence, beliefs).
or self.shape + (self.num_states.max(),) (for e.g. evidence, beliefs).

Returns:
A flat jnp.array for internal use
Expand Down Expand Up @@ -158,7 +160,8 @@ class VariableDict(groups.VariableGroup):

Args:
num_states: The size of the variables in this variable group
variable_names: A tuple of all names of the variables in this variable group
variable_names: A tuple of all names of the variables in this variable group.
Note that we overwrite variable_names to add the hash of the VariableDict
"""

variable_names: Tuple[Any, ...]
Expand All @@ -176,7 +179,7 @@ def __post_init__(self):
object.__setattr__(self, "variable_names", hash_and_names)

@cached_property
def variables(self) -> List[Tuple]:
def variables(self) -> List[Tuple[Tuple[Any, int], int]]:
"""Function that returns the list of all variables in the VariableGroup.
Each variable is represented by a tuple of the form (variable name, number of states)

Expand All @@ -188,24 +191,23 @@ def variables(self) -> List[Tuple]:
vars_num_states = self.num_states.flatten()
return list(zip(vars_names, vars_num_states))

def __getitem__(self, val: Any) -> Tuple[Any, int]:
def __getitem__(self, val: Any) -> Tuple[Tuple[Any, int], int]:
"""Given a variable name retrieve the associated variable, returned via a tuple of the form
(variable name, number of states)

Args:
val: a variable index or slice
val: a variable name

Returns:
The queried variable
"""
assert isinstance(self.num_states, np.ndarray)

if (self.__hash__(), val) not in self.variable_names:
raise ValueError(f"Variable {val} is not in VariableDict")
return ((self.__hash__(), val), self.num_states[0])

def flatten(
self, data: Mapping[Tuple[int, int], Union[np.ndarray, jnp.ndarray]]
self, data: Mapping[Tuple[Tuple[int, int], int], Union[np.ndarray, jnp.ndarray]]
) -> jnp.ndarray:
"""Function that turns meaningful structured data into a flat data array for internal use.

Expand Down
Loading