Skip to content

Commit 31e558b

Browse files
authored
Merge pull request #1 from DanWaxman/dw-dpf
Add Differentiable Particle Filter Support
2 parents 89b2e38 + 3adad9f commit 31e558b

File tree

11 files changed

+2998
-40
lines changed

11 files changed

+2998
-40
lines changed

cd_dynamax/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
EnKFHyperParams,
1414
)
1515

16+
from .src.continuous_discrete_nonlinear_ssm import (
17+
ContDiscreteNonlinearSSM,
18+
ParamsCDNLSSM,
19+
DPFHyperParams,
20+
cdnlssm_filter,
21+
)
22+
1623
# Linear SSM
1724
from .src.continuous_discrete_linear_gaussian_ssm import (
1825
ContDiscreteLinearGaussianSSM,
@@ -40,19 +47,23 @@
4047
__all__ = [
4148
# Models
4249
"ContDiscreteNonlinearGaussianSSM",
50+
"ContDiscreteNonlinearSSM",
4351
"ContDiscreteLinearGaussianSSM",
4452
"LinearGaussianSSM",
4553
# Params
4654
"ParamsCDNLGSSM",
55+
"ParamsCDNLSSM",
4756
"ParamsCDLGSSM",
4857
# Nonlinear algos
4958
"cdnlgssm_filter",
5059
"cdnlgssm_smoother",
5160
"cdnlgssm_forecast",
5261
"cdnlgssm_emissions",
62+
"cdnlssm_filter",
5363
"EKFHyperParams",
5464
"UKFHyperParams",
5565
"EnKFHyperParams",
66+
"DPFHyperParams",
5667
# Linear algos
5768
"cdlgssm_filter",
5869
"cdlgssm_smoother",

cd_dynamax/src/__init__.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
UKFHyperParams,
1010
EnKFHyperParams,
1111
)
12+
from .continuous_discrete_nonlinear_ssm import (
13+
ParamsCDNLSSM,
14+
ContDiscreteNonlinearSSM,
15+
DPFHyperParams,
16+
cdnlssm_filter,
17+
)
1218

1319
from .continuous_discrete_linear_gaussian_ssm import (
1420
ParamsCDLGSSM,
@@ -27,32 +33,33 @@
2733
__all__ = [
2834
### SSM classes ###
2935
"ContDiscreteNonlinearGaussianSSM",
36+
"ContDiscreteNonlinearSSM",
3037
"ContDiscreteLinearGaussianSSM",
31-
3238
### Param classes ###
3339
"ParamsCDNLGSSM",
40+
"ParamsCDNLSSM",
3441
"ParamsCDLGSSM",
35-
3642
### Non-linear filters/smoothers/forecasters ###
3743
"cdnlgssm_filter",
3844
"cdnlgssm_smoother",
3945
"cdnlgssm_forecast",
4046
"cdnlgssm_emissions",
41-
4247
# Non-linear filtering Hyperparams
4348
"EKFHyperParams",
4449
"UKFHyperParams",
4550
"EnKFHyperParams",
46-
51+
"DPFHyperParams",
52+
"cdnlssm_filter",
4753
### Linear filters/smoothers/forecasters/samplers ###
4854
"cdlgssm_filter",
4955
"cdlgssm_smoother",
5056
"cdlgssm_forecast",
5157
"cdlgssm_emissions",
5258
"cdlgssm_posterior_sample",
5359
"cdlgssm_joint_sample",
54-
5560
# Linear filtering Hyperparams (typically don't need to use these directly, can rely on default KFHyperParams)
5661
"KFHyperParams",
57-
62+
# Base interfaces
63+
"SSM",
64+
"Prior",
5865
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .cdnlssm_utils import ParamsCDNLSSM
2+
from .cdnlssm_utils import ParamsCDNLSSMDynamics, ParamsCDNLSSMEmissions
3+
4+
from .models import ContDiscreteNonlinearSSM
5+
from .models import cdnlssm_filter
6+
7+
from .inference_dpf import DPFHyperParams, filter_dpf
8+
9+
from .builders import build_params
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import equinox as eqx
2+
import inspect
3+
from jax.scipy.special import gammaln
4+
import jax.numpy as jnp
5+
import tensorflow_probability.substrates.jax.distributions as tfd
6+
7+
from .cdnlssm_utils import (
8+
ParamsCDNLSSM,
9+
ParamsCDNLSSMEmissions,
10+
ParamsCDNLSSMInitial,
11+
ParamsCDNLSSMDynamics,
12+
)
13+
14+
from ..continuous_discrete_nonlinear_gaussian_ssm.builders import (
15+
wrap_function_or_constant,
16+
_default,
17+
)
18+
19+
20+
# -------------------------
21+
# Distribution wrappers
22+
# -------------------------
23+
class StaticDistribution(eqx.Module):
    """Adapter around a fixed distribution object.

    ``log_prob`` accepts the same ``(y, x, u, t)`` call shape as the other
    wrappers in this module, but only ``y`` is forwarded to the wrapped
    distribution; the conditioning inputs are ignored.
    """

    distribution: tfd.Distribution

    def log_prob(self, y, x=None, u=None, t=None):
        """Return ``distribution.log_prob(y)``; ``x``, ``u`` and ``t`` are unused."""
        wrapped = self.distribution
        return wrapped.log_prob(y)

    def sample(self, *args, **kwargs):
        """Forward every argument unchanged to ``distribution.sample``."""
        wrapped = self.distribution
        return wrapped.sample(*args, **kwargs)
31+
32+
33+
class ConditionalDistribution(eqx.Module):
    """Distribution produced on demand by a user-supplied factory.

    ``distribution_fn`` is invoked with whichever of the keyword arguments
    ``x``, ``u`` and ``t`` appear by name in its signature; the object it
    returns must expose ``log_prob`` and ``sample``.
    """

    distribution_fn: callable

    def _dist(self, x=None, u=None, t=None):
        """Call ``distribution_fn`` with only the keywords its signature names."""
        accepted = inspect.signature(self.distribution_fn).parameters
        candidates = (("x", x), ("u", u), ("t", t))
        selected = {name: value for name, value in candidates if name in accepted}
        return self.distribution_fn(**selected)

    def log_prob(self, y, x=None, u=None, t=None):
        """Return the log-probability of ``y`` under the distribution built from ``(x, u, t)``."""
        conditional = self._dist(x, u, t)
        return conditional.log_prob(y)

    def sample(self, x=None, u=None, t=None, *args, **kwargs):
        """Sample from the distribution built from ``(x, u, t)``.

        Extra positional/keyword arguments are forwarded to the built
        distribution's ``sample`` method.
        """
        conditional = self._dist(x, u, t)
        return conditional.sample(*args, **kwargs)
52+
53+
54+
class GaussianEmission(eqx.Module):
    """Gaussian emission whose mean and covariance are evaluated at ``(x, u, t)``.

    Both fields are expected to expose an ``f(x, u, t)`` method; the results
    parameterize a ``tfd.MultivariateNormalFullCovariance``.
    """

    emission_function: eqx.Module
    emission_cov: eqx.Module

    def _conditional(self, x, u, t):
        """Build the multivariate normal for the given state, input and time."""
        loc = self.emission_function.f(x, u, t)
        covariance = self.emission_cov.f(x, u, t)
        return tfd.MultivariateNormalFullCovariance(loc, covariance)

    def log_prob(self, y, x=None, u=None, t=None):
        """Return the log-density of ``y`` under the emission at ``(x, u, t)``."""
        return self._conditional(x, u, t).log_prob(y)

    def sample(self, x=None, u=None, t=None, *args, **kwargs):
        """Draw from the emission at ``(x, u, t)``; extras go to the TFP ``sample`` call."""
        return self._conditional(x, u, t).sample(*args, **kwargs)
67+
68+
69+
class PoissonEmission(eqx.Module):
    """Poisson emission driven by the first state component.

    The log-rate is ``x[..., 0] + bias + log(dt)``.
    """

    dt: float = eqx.field(static=True)
    bias: jnp.ndarray

    def _log_rate(self, x):
        """Return the log-rate ``x[..., 0] + bias + log(dt)``."""
        return x[..., 0] + self.bias + jnp.log(self.dt)

    def log_prob(self, y, x=None, u=None, t=None):
        """Return the Poisson log-pmf of ``y`` at the log-rate implied by ``x``.

        ``y`` is cast to the log-rate's dtype and squeezed before use; ``u``
        and ``t`` are unused.
        """
        log_rate = self._log_rate(x)
        counts = jnp.squeeze(jnp.asarray(y, dtype=log_rate.dtype))
        rate = jnp.exp(log_rate)
        log_factorial = gammaln(counts + 1.0)
        return counts * log_rate - rate - log_factorial

    def sample(self, x=None, u=None, t=None, *args, **kwargs):
        """Draw counts from ``tfd.Poisson``; extras go to the TFP ``sample`` call."""
        poisson = tfd.Poisson(log_rate=self._log_rate(x))
        return poisson.sample(*args, **kwargs)
81+
82+
83+
class TransformedDistribution(eqx.Module):
    """Compose a conditional base distribution with a transform.

    ``base_distribution`` must follow this module's conditional wrapper
    interface (``log_prob(y, x, u, t)`` and ``sample(x, u, t, *args, **kwargs)``);
    ``transform`` must expose ``f(value, u, t)``.
    """

    base_distribution: eqx.Module
    transform: eqx.Module

    def log_prob(self, y, x=None, u=None, t=None):
        """Score ``transform.f(y, u, t)`` under the base distribution at ``(x, u, t)``.

        NOTE(review): no change-of-variables (log-det-Jacobian) term is added
        here, and the same ``transform.f`` is applied in both ``log_prob`` and
        ``sample`` -- confirm this is the intended convention.
        """
        transformed = self.transform.f(y, u, t)
        return self.base_distribution.log_prob(transformed, x, u, t)

    def sample(self, x=None, u=None, t=None, *args, **kwargs):
        """Sample the base distribution at ``(x, u, t)`` and transform the draw.

        Extra positional/keyword arguments (e.g. ``seed``, ``sample_shape``)
        are forwarded to the base distribution's ``sample``, matching the other
        wrappers in this module. They were previously passed to ``transform.f``
        instead, which left the sampler without them.
        """
        base_sample = self.base_distribution.sample(x, u, t, *args, **kwargs)
        # The transform only receives (value, u, t), like every other callable here.
        return self.transform.f(base_sample, u, t)
93+
94+
95+
# -------------------------
96+
# Param construction
97+
# -------------------------
98+
def _build_initial_distribution(
    state_dim, initial_distribution, initial_mean, initial_cov
):
    """Normalize the initial-state specification into a distribution wrapper.

    Args:
        state_dim: Dimension of the latent state; sizes the default Gaussian.
        initial_distribution: ``None``, a ``tfd.Distribution``, an object that
            already has a ``distribution`` attribute (e.g. a
            ``StaticDistribution``), a zero-argument callable returning a
            distribution, or any other distribution-like object.
        initial_mean: Mean of the default Gaussian; zeros when ``None``.
        initial_cov: Covariance of the default Gaussian; identity when ``None``.

    Returns:
        The given wrapper unchanged, or a ``StaticDistribution`` around the
        resolved distribution.
    """
    if initial_distribution is not None:
        # Check for raw TFP distributions first, as _build_emission_distribution
        # does. Wrapper distributions such as tfd.Independent and
        # tfd.TransformedDistribution expose a ``distribution`` attribute and
        # would otherwise be mistaken for an already-built wrapper below.
        if isinstance(initial_distribution, tfd.Distribution):
            return StaticDistribution(distribution=initial_distribution)
        if hasattr(initial_distribution, "distribution"):
            return initial_distribution
        if callable(initial_distribution):
            return StaticDistribution(distribution=initial_distribution())
        return StaticDistribution(distribution=initial_distribution)

    # Default to a Gaussian if initial_distribution is None
    mean = wrap_function_or_constant(
        _default(initial_mean, jnp.zeros(state_dim)), expected_shape=(state_dim,)
    ).f()
    cov = wrap_function_or_constant(
        _default(initial_cov, jnp.eye(state_dim)), expected_shape=(state_dim, state_dim)
    ).f()
    return StaticDistribution(
        distribution=tfd.MultivariateNormalFullCovariance(mean, cov)
    )
118+
119+
120+
def _build_emission_distribution(
    emission_dim, emission_distribution, emission_function, emission_cov
):
    """Normalize the emission specification into a distribution wrapper.

    With no ``emission_distribution``, a ``GaussianEmission`` is built from
    ``emission_function`` (default: the first ``emission_dim`` state
    components) and ``emission_cov`` (default: identity). Otherwise a
    ``tfd.Distribution`` is wrapped as static, an object that already has
    ``log_prob`` is returned unchanged, and anything else is treated as a
    factory for a ``ConditionalDistribution``.
    """
    if emission_distribution is None:
        mean_source = _default(
            emission_function, lambda x, u=None, t=None: x[..., :emission_dim]
        )
        cov_source = _default(emission_cov, jnp.eye(emission_dim))
        return GaussianEmission(
            emission_function=wrap_function_or_constant(mean_source),
            emission_cov=wrap_function_or_constant(
                cov_source, expected_shape=(emission_dim, emission_dim)
            ),
        )

    if isinstance(emission_distribution, tfd.Distribution):
        return StaticDistribution(distribution=emission_distribution)
    if hasattr(emission_distribution, "log_prob"):
        return emission_distribution
    return ConditionalDistribution(distribution_fn=emission_distribution)
141+
142+
143+
def build_params(
    state_dim,
    emission_dim,
    drift,
    initial_mean=None,
    initial_cov=None,
    emission_function=None,
    emission_cov=None,
    initial_distribution=None,
    emission_distribution=None,
    diffusion_coeff=None,
    diffusion_cov=None,
    approx_order: float = 1.0,
):
    """
    Assemble a ``ParamsCDNLSSM`` through an interface that mirrors CDNLGSSM.

    Callable parameters should accept only (x, u, t) or a subset of these.
    ``diffusion_coeff`` and ``diffusion_cov`` default to the
    ``state_dim`` x ``state_dim`` identity; the initial and emission
    distributions are resolved by the module's private builder helpers.
    """
    square = (state_dim, state_dim)

    initial = ParamsCDNLSSMInitial(
        initial_distribution=_build_initial_distribution(
            state_dim, initial_distribution, initial_mean, initial_cov
        )
    )

    drift_fn = wrap_function_or_constant(drift)
    diffusion_coefficient_fn = wrap_function_or_constant(
        _default(diffusion_coeff, jnp.eye(state_dim)), expected_shape=square
    )
    diffusion_cov_fn = wrap_function_or_constant(
        _default(diffusion_cov, jnp.eye(state_dim)), expected_shape=square
    )
    dynamics = ParamsCDNLSSMDynamics(
        drift=drift_fn,
        diffusion_coefficient=diffusion_coefficient_fn,
        diffusion_cov=diffusion_cov_fn,
        approx_order=approx_order,
    )

    emissions = ParamsCDNLSSMEmissions(
        emission_distribution=_build_emission_distribution(
            emission_dim, emission_distribution, emission_function, emission_cov
        ),
    )

    return ParamsCDNLSSM(initial=initial, dynamics=dynamics, emissions=emissions)

0 commit comments

Comments
 (0)