Skip to content

Commit 3adad9f

Browse files
committed
Documentation + ESS-based resampling
1 parent 37f23b5 commit 3adad9f

File tree

5 files changed

+448
-243
lines changed

5 files changed

+448
-243
lines changed

cd_dynamax/src/continuous_discrete_nonlinear_ssm/cdnlssm_utils.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from cd_dynamax.dynamax.parameters import ParameterProperties, ParameterSet
77
import tensorflow_probability.substrates.jax.distributions as tfd
88

9-
109
from ..continuous_discrete_nonlinear_gaussian_ssm.cdnlgssm_utils import (
1110
_get_params,
1211
LearnableFunction,
1312
LearnableVector,
1413
LearnableMatrix,
14+
ParamsCDNLGSSMDynamics,
1515
)
1616

1717

@@ -67,12 +67,8 @@ def sample(self, x, u=None, t=None):
6767
return self.transform.f(base_sample, u, t)
6868

6969

70-
# Dynamics container for CD-NLSSM
71-
class ParamsCDNLSSMDynamics(NamedTuple):
72-
drift: LearnableFunction
73-
diffusion_coefficient: LearnableFunction
74-
diffusion_cov: LearnableFunction
75-
approx_order: Union[float, ParameterProperties]
70+
# Currently, we only support Brownian motion-driven SDEs, so we can reuse the CDNLGSSM dynamics parameters
71+
ParamsCDNLSSMDynamics = ParamsCDNLGSSMDynamics
7672

7773

7874
## CDNLSSM parameter class definitions
@@ -88,17 +84,16 @@ class ParamsCDNLSSMEmissions(NamedTuple):
8884

8985
# CDNLGSSM parameters are different to CDLGSSM due to nonlinearities
9086
class ParamsCDNLSSM(NamedTuple):
91-
r"""Parameters of a nonlinear Gaussian SSM.
87+
r"""Parameters of a continuous-discrete nonlinear SSM.
9288
9389
:param initial: initial distribution parameters
9490
:param dynamics: dynamics distribution parameters
9591
:param emissions: emission distribution parameters
9692
9793
The assumed transition and emission distributions are
98-
$$p(z_1) = N(z_1 | m, S)$$
99-
$$p(z_t | z_{t-1}, u_t) = N(z_t | m_t, P_t)$$
100-
$$p(y_t | z_t) = N(y_t | h(z_t, u_t), R_t)$$
101-
94+
$$p(z_0) = p_{\mathrm{initial}}(z_0)$$
95+
$$p(z_{t_k} \mid z_{t_{k-1}}, u_{t_k}) = \mathrm{solve\_sde}(z_{t_{k-1}}, u_{t_k}, t_{k-1}, t_k, f_{\mathrm{dynamics}}, L_{\mathrm{dynamics}}, Q_{\mathrm{dynamics}})$$
96+
$$p(y_{t_k} \mid z_{t_k}) = p_{\mathrm{emissions}}(y_{t_k} \mid z_{t_k})$$
10297
"""
10398

10499
initial: ParamsCDNLSSMInitial

cd_dynamax/src/continuous_discrete_nonlinear_ssm/inference_dpf.py

Lines changed: 140 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import NamedTuple
1+
from typing import NamedTuple, Union, Tuple
22

33
from jaxtyping import Array, Float
44
from .cdnlssm_utils import ParamsCDNLSSM
@@ -14,12 +14,17 @@ class DPFHyperParams(NamedTuple):
1414

1515
dt_final: float = 1e-4
1616
N_particles: int = 100
17+
ess_threshold_ratio: float = 0.5
1718
resample_method: str = "stop_gradient"
1819
softness: float = 0.7
1920
cov_rescaling: float = 1.0
2021
state_order: str = "first"
2122
dt_average: float = 0.1
2223
diffeqsolve_settings: dict = {}
24+
return_ess_history: bool = False
25+
proposal_method: str = (
26+
"bootstrap" # Currently, only bootstrap proposals are supported.
27+
)
2328

2429

2530
def _predict(
@@ -31,7 +36,20 @@ def _predict(
3136
u,
3237
filter_hyperparams,
3338
):
34-
"""Predict evolution of ensemble of particles through the nonlinear stochastic dynamics."""
39+
"""Predict evolution of ensemble of particles through the nonlinear stochastic dynamics.
40+
41+
Args:
42+
key: Random key.
43+
x: Particles to predict.
44+
params: Parameters of the CDNLSSM.
45+
t0: Initial time.
46+
t1: Final time.
47+
u: Inputs.
48+
filter_hyperparams: Hyperparameters of the filter.
49+
50+
Returns:
51+
x_pred: Predicted particles.
52+
"""
3553

3654
def drift(t, y, args):
3755
return params.dynamics.drift.f(y, u, t)
@@ -95,11 +113,17 @@ def _normalize_log_weights(log_w):
95113
return log_w - log_norm, log_norm
96114

97115

116+
def _effective_sample_size(log_w):
117+
return jnp.exp(-logsumexp(2.0 * log_w))
118+
119+
98120
_SUPPORTED_RESAMPLERS = ("multinomial", "soft", "stop_gradient")
99121

100122

101123
def _validate_resample_method(method: str) -> str:
102-
"""Ensure the requested resampling strategy is implemented."""
124+
"""Ensure the requested resampling strategy is implemented.
125+
126+
See :attr:`_SUPPORTED_RESAMPLERS` for supported methods."""
103127
method_lower = method.lower()
104128
if method_lower not in _SUPPORTED_RESAMPLERS:
105129
raise ValueError(
@@ -110,6 +134,21 @@ def _validate_resample_method(method: str) -> str:
110134

111135

112136
def _multinomial_resample(key, x, log_w):
137+
"""Multinomial resampling.
138+
139+
This is the classical resampling strategy for particle filters,
140+
where each particle is resampled with probability proportional to its weight.
141+
142+
Args:
143+
key: Random key.
144+
x: Particles to resample.
145+
log_w: Log weights of the particles.
146+
147+
Returns:
148+
x_resampled: Resampled particles.
149+
log_w_resampled: Log weights of the resampled particles.
150+
idx: Indices of the resampled particles.
151+
"""
113152
probs = jnp.exp(log_w - logsumexp(log_w))
114153
idx = jr.choice(key, x.shape[0], shape=(x.shape[0],), p=probs, replace=True)
115154
x_resampled = x[idx]
@@ -118,7 +157,25 @@ def _multinomial_resample(key, x, log_w):
118157

119158

120159
def _stop_gradient_resample(key, x, log_w, *, base_method: str = "multinomial"):
121-
"""Resample while passing gradients through the chosen particles."""
160+
"""Stop-gradient resampling [1].
161+
162+
Stop-gradient resampling [1] modifies the resampling step such that forward passes are not modified.
163+
164+
References:
165+
[1] Scibior A, Wood F (2021). “Differentiable particle filtering without modifying the forward pass.” arXiv:2106.10314
166+
167+
Args:
168+
key: Random key.
169+
x: Particles to resample.
170+
log_w: Log weights of the particles.
171+
base_method: Base resampling method to use.
172+
173+
Returns:
174+
x_resampled: Resampled particles.
175+
log_w_resampled: Log weights of the resampled particles.
176+
idx: Indices of the resampled particles.
177+
178+
"""
122179
if base_method != "multinomial":
123180
raise ValueError(
124181
f"Unsupported base resampler '{base_method}' for stop_gradient resampling."
@@ -131,6 +188,27 @@ def _stop_gradient_resample(key, x, log_w, *, base_method: str = "multinomial"):
131188

132189

133190
def _soft_resample(key, x, log_w, softness):
191+
"""Soft resampling [1].
192+
193+
Soft resampling approximates differentiable resampling by mixing the weights w with a uniform distribution,
194+
q(k) = softness * w[k] + (1-softness) / n_particles,
195+
then reweighting via importance weights.
196+
197+
This strategy generally provides biased gradients, but can still be an efficient approximation.
198+
199+
References:
200+
[1] Karkus P, Hsu D, Lee WS (2018). “Particle filter networks with application to visual localization.” In Proc. Conf. Robot Learn., pp. 169–178. PMLR, Zurich, CH.
201+
202+
Args:
203+
key: Random key.
204+
x: Particles to resample.
205+
log_w: Log weights of the particles.
206+
softness: Softness parameter.
207+
208+
Returns:
209+
x_resampled: Resampled particles.
210+
log_w_resampled: Log weights of the resampled particles.
211+
"""
134212
n = x.shape[0]
135213
log_n = jnp.log(n)
136214
log_softness = jnp.log(softness)
@@ -151,11 +229,44 @@ def filter_dpf(
151229
us: Array | None = None,
152230
ts: Array | None = None,
153231
hyperparams: DPFHyperParams = DPFHyperParams(),
154-
):
155-
"""Differentiable particle filter with configurable resampling (default stop-gradient)."""
232+
) -> Union[Tuple[Array, Array, Array, float], Tuple[Array, Array, Array, Array, float]]:
233+
"""Differentiable particle filter with configurable resampling.
234+
235+
A differentiable particle filter (DPF) is a particle filter in which the (discrete, non-differentiable) resampling step is replaced by a gradient-compatible surrogate, enabling gradient-based optimization of model parameters.
236+
This implementation supports three different resampling methods:
237+
- Multinomial resampling (biased)
238+
- Soft resampling [1] (biased; interpolates between multinomial and uniform resampling)
239+
- Stop-gradient resampling [2] (unbiased for score estimates)
240+
241+
Currently, only bootstrap proposals are supported.
242+
243+
References:
244+
[1] Karkus P, Hsu D, Lee WS (2018). “Particle filter networks with application to visual localization.” In Proc. Conf. Robot Learn., pp. 169–178. PMLR, Zurich, CH.
245+
[2] Scibior A, Wood F (2021). “Differentiable particle filtering without modifying the forward pass.” arXiv:2106.10314
246+
247+
Args:
248+
key: Random key.
249+
params: Parameters of the CDNLSSM.
250+
ys: Emissions.
251+
us: Inputs.
252+
ts: Times.
253+
hyperparams: Hyperparameters of the filter.
254+
255+
Returns:
256+
particles: Particles.
257+
log_weights: Log weights.
258+
ess_history: (if return_ess_history is True) Effective sample size history.
259+
log_evidence: Log evidence.
260+
"""
156261
n_particles = int(hyperparams.N_particles)
157262
T = ys.shape[0]
158263
resample_method = _validate_resample_method(hyperparams.resample_method)
264+
ess_threshold = hyperparams.ess_threshold_ratio * n_particles
265+
266+
if hyperparams.proposal_method != "bootstrap":
267+
raise ValueError(
268+
f"Currently, only bootstrap proposals are supported, but {hyperparams.proposal_method} was provided."
269+
)
159270

160271
key_init, key = jr.split(key)
161272
particles = params.initial.initial_distribution.distribution.sample(
@@ -200,28 +311,37 @@ def _do_predict(p):
200311
)
201312
log_w, log_norm = _normalize_log_weights(log_w)
202313
log_evidence = log_evidence + log_norm
314+
ess = _effective_sample_size(log_w)
203315

204316
particles_hist = particles
205317
logw_hist = log_w
206318

207-
# TODO: Only resample under criteria, e.g., ESS < threshold.
208-
if resample_method == "soft":
209-
particles, log_w = _soft_resample(
210-
key_resample, particles, log_w, hyperparams.softness
211-
)
212-
elif resample_method == "multinomial":
213-
particles, log_w, _ = _multinomial_resample(key_resample, particles, log_w)
214-
else: # stop_gradient
215-
particles, log_w = _stop_gradient_resample(
216-
key_resample, particles, log_w, base_method="multinomial"
319+
def _resample(args):
320+
x_in, log_w_in, key_in = args
321+
if resample_method == "soft":
322+
return _soft_resample(key_in, x_in, log_w_in, hyperparams.softness)
323+
if resample_method == "multinomial":
324+
x_out, log_w_out, _ = _multinomial_resample(key_in, x_in, log_w_in)
325+
return x_out, log_w_out
326+
return _stop_gradient_resample(
327+
key_in, x_in, log_w_in, base_method="multinomial"
217328
)
218329

219-
return (particles, log_w, log_evidence), (particles_hist, logw_hist)
330+
particles, log_w = lax.cond(
331+
ess < ess_threshold,
332+
_resample,
333+
lambda args: (args[0], args[1]),
334+
(particles, log_w, key_resample),
335+
)
220336

221-
(particles, log_w, log_evidence), (particles_hist, logw_hist) = lax.scan(
337+
return (particles, log_w, log_evidence), (particles_hist, logw_hist, ess)
338+
339+
(particles, log_w, log_evidence), (particles_hist, logw_hist, ess_hist) = lax.scan(
222340
_step,
223341
(particles, log_w, log_evidence),
224342
(keys, idxs, t_currs, t_prevs, u_prevs, u_currs),
225343
)
226-
227-
return particles_hist, logw_hist, log_evidence
344+
if hyperparams.return_ess_history:
345+
return particles_hist, logw_hist, ess_hist, log_evidence
346+
else:
347+
return particles_hist, logw_hist, log_evidence

0 commit comments

Comments
 (0)