
Commit dfaba22

cweniger and claude authored
Ref-based object store with dict trace and three-layer architecture (#24)

* Implement ref-based data passing between actors

  Convert data passing to use Ray ObjectRefs throughout, eliminating
  unnecessary serialization through the driver. Key changes:

  - NodeWrapper.sample() now returns List[Dict[str, ObjectRef]] with internal
    chunking based on the ray.chunk_size config
  - Add _simulate(), _sample_posterior(), _sample_proposal() as internal
    methods for the actual sampling logic
  - Add _convert_rvbatch_to_refs() helper for RVBatch-to-refs conversion
  - MultiplexNodeWrapper updated for ref-based multiplexed sampling
  - DeployedGraph gains _resolve_refs_to_trace(), _merge_refs(),
    _refs_to_arrays() helpers for graph traversal with refs
  - DatasetManagerActor.append_refs() stores ObjectRefs directly
  - Add _dump_refs() for lazy disk dumps when store_fraction > 0
  - Fix resampling bug by merging proposal_refs with sample_refs

  This enables zero-copy data flow between simulation and buffer actors,
  supports heterogeneous sample shapes, and makes the buffer shape-agnostic.

* Simplify internal sampling interface with (value, logprob) tuples

  Replace RVBatch wrapper handling with explicit tuple returns:

  - _simulate() now returns a (value, logprob) tuple
  - _sample_posterior() returns a (value, logprob) tuple
  - _sample_proposal() returns a (value, logprob) tuple
  - Rename _convert_rvbatch_to_refs() to _batch_to_refs(value, logprob)

  This removes hasattr checks and as_rvbatch() calls, making the internal
  interface explicit. RVBatch is still used by estimators internally, but the
  boundary is now a simple tuple.

* Replace RVBatch with dict-based trace and ref-only outer layer

  Three-layer architecture: the outer layer (DeployedGraph/store) passes only
  ObjectRefs, NodeWrapper handles the ref↔array boundary with batched
  ray.get/put, and the inner layer (estimators/simulators) sees pure
  arrays/tensors.

  - Remove RVBatch/RV classes entirely
  - Estimators return {'value': arr, 'log_prob': arr} dicts
  - Flat store keys: theta.value, theta.log_prob (not bare theta)
  - NodeWrapper: unified _chunked_sample, _resolve_refs (batched ray.get),
    _batch_to_refs (phased puts), condition_refs interface
  - MultiplexNodeWrapper: factored _multiplexed_call, slices ref lists
  - DeployedGraph: ref-only _execute_graph (no double serialization),
    _arrays_to_condition_refs, _extract_value_refs, batched _refs_to_arrays
  - Fix cli.py sample_mode for the new ref-based return format
  - Fix read-only array warning from Ray zero-copy returns

* Remove unnecessary pause/resume during resimulation

  sample_proposal interleaves with training via async yield points
  (await asyncio.sleep(0) in the training loop). Forward simulation runs on
  separate actors concurrently with training, so there is no need to pause
  training during the entire resimulation cycle.

* Optimize ref resolution: broadcast detection and threaded ray.put

  _resolve_refs now detects when all refs are identical (the broadcast case,
  e.g. repeated observations) and resolves once with np.broadcast_to instead
  of fetching N copies. This reduces resolve_refs from ~280ms to ~1ms for
  broadcast observations.

  _batch_to_refs uses ThreadPoolExecutor(4) for ray.put calls, giving ~1.3x
  speedup for small arrays and ~4x for large arrays (the GIL is released
  during serialization).

  _to_tensor handles read-only numpy arrays (from np.broadcast_to) by copying
  before torch.from_numpy conversion.

* Avoid materializing broadcast arrays in _to_tensor

  Use np.asarray (a no-op) instead of np.array (a 100MB copy) for read-only
  broadcast arrays from _resolve_refs. The PyTorch non-writable warning is
  harmless since .to(device) always creates a writable copy.

* Defer broadcast expansion and background data fetching in core

  Lazy broadcast: _resolve_refs now returns compact (1, ...) arrays for
  broadcast conditions, deferring np.broadcast_to to the point of use in
  _simulate. This reduces memory allocation during chunked sampling.

  Background fetch: CachedDataLoader.sync() offloads ray.get to a background
  thread so the training loop is not blocked. New data is applied one epoch
  later; the first sync blocks to ensure initial data.

* Move np.stack + tensor conversion into the background fetch thread

  _apply_fetch was taking ~4s per sync (1024 samples) because it ran
  np.stack + _to_tensor per sample on the main thread, blocking GPU training.
  Now _checkout_and_fetch does the full pipeline in the background thread
  (checkout_refs, ray.get, np.stack, torch.as_tensor), returning ready-to-use
  tensors. _apply_fetch on the main thread only does a fast batched
  scatter/cat (~0.2s). Also makes sync() truly non-blocking by checking
  .done() before applying.

  Measured: _apply_fetch 4.0s → 0.2s, epoch time 14.2s → 10s.

* Fix dtype mismatch in _apply_fetch scatter assignment

  The background thread's torch.as_tensor may produce a different precision
  (e.g. Float vs Double) than the existing cache arrays. Cast new tensors to
  match the destination dtype before indexed assignment.

* Remove dead code and fix missing joblib import

  Cleanup:
  - Remove unused call_simulator_method, call_estimator_method (NodeWrapper)
  - Remove duplicate shutdown() stub in NodeWrapper
  - Remove unused shutdown() in DatasetManagerActor
  - Remove unused load_observations() and its imports from utils.py
  - Remove leftover debug log in _merge_refs
  - Simplify _execute_graph: deduplicate the parent loop in the conditional

  Bugfix:
  - Add missing joblib import in raystore.py (used by load_initial_samples)

* Suppress Ray blocking ray.get warning in async training loop

  buffer.get_stats() does a synchronous ray.get to fetch buffer statistics,
  which triggers a Ray warning about blocking in async actors. The call is
  fast (<1ms) and only happens once per epoch, so the blocking is negligible.
  Added a warnings.catch_warnings context manager to suppress this specific
  warning at the two call sites in stepwise_estimator.py, with a TODO noting
  potential future improvements (async get_stats or background caching).

* Suppress Ray 'blocking ray.get inside async actor' warning

  Set Ray's internal blocking_get_inside_async_warned flag to True in
  NodeWrapper.__init__ to prevent the warning from being logged. The warning
  is unavoidable without splitting async training actors from synchronous
  sampling actors.

  Also removes the ineffective warnings.filterwarnings context managers from
  stepwise_estimator.py that were previously added but didn't work (Ray uses
  logging.warning, not warnings.warn).

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
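Two of the optimizations above land in files not shown in the truncated diff below, so here is a minimal sketch of each. First, broadcast detection in _resolve_refs plus threaded ray.put in _batch_to_refs: the function names come from the commit message, but the bodies are illustrative guesses, not falcon's actual code.

from concurrent.futures import ThreadPoolExecutor

import numpy as np
import ray


def _resolve_refs(refs):
    """Resolve a list of ObjectRefs into one stacked array.

    Broadcast case: if every ref is the same object (e.g. one observation
    repeated N times), fetch it once and return a compact (1, ...) array;
    np.broadcast_to is deferred to the point of use.
    """
    if len(set(refs)) == 1:
        return np.asarray(ray.get(refs[0]))[None, ...]
    return np.stack(ray.get(list(refs)))


def _batch_to_refs(chunks):
    """ray.put each chunk from a small thread pool; serialization releases
    the GIL, so four threads overlap the copies."""
    with ThreadPoolExecutor(max_workers=4) as pool:
        return list(pool.map(ray.put, chunks))

Second, the non-blocking CachedDataLoader.sync(): again only a sketch under the assumptions stated in the message (the background thread runs the full checkout/ray.get/stack/convert pipeline, supplied here as a fetch_fn; the main thread only applies finished fetches).

from concurrent.futures import ThreadPoolExecutor

import torch


class CachedDataLoader:
    """Keeps a tensor cache fresh without blocking the training loop."""

    def __init__(self, fetch_fn):
        # fetch_fn: checkout_refs + ray.get + np.stack + torch.as_tensor,
        # returning a dict of ready-to-use tensors
        self._fetch_fn = fetch_fn
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._future = None
        self.cache = {}

    def sync(self):
        if not self.cache:
            # First sync blocks so training starts with real data.
            self.cache = self._fetch_fn()
        elif self._future is not None:
            if not self._future.done():
                return  # background fetch still running; keep the old cache
            # Apply the fetch started on a previous sync; new data lands
            # one epoch late, but the scatter itself is cheap.
            for k, t in self._future.result().items():
                dst = self.cache[k]
                idx = torch.arange(t.shape[0])
                # Indexed assignment requires matching dtypes, so cast the
                # background thread's tensors (possibly Float vs Double) first.
                dst[idx] = t.to(dtype=dst.dtype)
        # Kick off the next background fetch and return immediately.
        self._future = self._executor.submit(self._fetch_fn)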
1 parent 260d871 commit dfaba22

File tree

8 files changed: +583 -353 lines changed

examples/05_linear_regression/config.yaml

Lines changed: 2 additions & 0 deletions

@@ -110,6 +110,7 @@ graph:
     ray:
       num_gpus: 0.5
+      chunk_size: 128

   y:
     parents: [theta]
@@ -120,6 +121,7 @@ graph:
     observed: "./data/mock_data.npz['y']"
     ray:
       num_gpus: 0.5
+      chunk_size: 128

 # -----------------------------------------------------------------------------
 # Sampling configuration
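The new chunk_size key bounds how many samples travel through each ray.put when a node converts its output to ObjectRefs. A minimal sketch of the chunking described in the commit message (the loop body is illustrative, not falcon's actual _chunked_sample):

import numpy as np
import ray


def chunked_sample(sample_fn, num_samples, chunk_size=128):
    """Call sample_fn in chunks of chunk_size and return one dict of
    ObjectRefs per chunk, so consumers fetch data actor-to-actor instead
    of routing full arrays through the driver."""
    refs = []
    for start in range(0, num_samples, chunk_size):
        n = min(chunk_size, num_samples - start)
        value, log_prob = sample_fn(n)  # inner layer returns plain arrays
        refs.append({
            "value": ray.put(np.asarray(value)),
            "log_prob": ray.put(np.asarray(log_prob)),
        })
    return refs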

falcon/cli.py

Lines changed: 23 additions & 15 deletions

@@ -411,46 +411,54 @@ def sample_mode(cfg, sample_type: str) -> None:

     if sample_type == "prior":
         # Generate forward samples from prior
-        samples = deployed_graph.sample(num_samples)
+        sample_refs = deployed_graph.sample(num_samples)

     elif sample_type == "posterior":
-        # TODO: Implement posterior sampling (requires trained model and observations)
         deployed_graph.load(Path(cfg.paths.graph))
-        samples = deployed_graph.sample_posterior(num_samples, observations)
+        sample_refs = deployed_graph.sample_posterior(num_samples, observations)

     elif sample_type == "proposal":
-        # Proposal sampling requires observations for conditioning
-        # Load observations from config
         deployed_graph.load(Path(cfg.paths.graph))
-        samples = deployed_graph.sample_proposal(num_samples, observations)
+        sample_refs = deployed_graph.sample_proposal(num_samples, observations)

     else:
         raise ValueError(f"Unknown sample type: {sample_type}")

-    # Apply smart key selection based on mode and user overrides
+    # Resolve refs to arrays (keys are flat: 'theta.value', 'theta.log_prob', 'x.value', ...)
+    samples = deployed_graph._refs_to_arrays(sample_refs)
+
+    # Build key selection based on node names (strip .value/.log_prob suffixes)
+    node_keys = {k for k in samples.keys() if k.endswith('.value')}
+
     if sample_type in ["prior", "proposal"]:
-        # Default: save everything
-        default_keys = set(samples.keys())
+        # Default: save all .value keys
+        default_keys = set(node_keys)
     elif sample_type == "posterior":
         # Default: save only posterior nodes (nodes with evidence)
         default_keys = {
-            k for k, node in graph.node_dict.items() if node.evidence and k in samples
+            f"{k}.value" for k, node in graph.node_dict.items()
+            if node.evidence and f"{k}.value" in samples
         }

-    # Apply user overrides
+    # Apply user overrides (user specifies node names, we match .value keys)
     exclude_keys = sample_cfg.get("exclude_keys", None)
     add_keys = sample_cfg.get("add_keys", None)

     if exclude_keys:
-        exclude_set = set(exclude_keys.split(","))
+        exclude_set = {f"{k}.value" for k in exclude_keys.split(",")}
         default_keys -= exclude_set

     if add_keys:
-        add_set = set(add_keys.split(","))
+        add_set = {f"{k}.value" for k in add_keys.split(",")}
         default_keys |= add_set

-    # Filter samples to selected keys
-    save_data = {k: samples[k] for k in default_keys if k in samples}
+    # Filter samples to selected keys, strip .value suffix for user-facing output
+    save_data = {}
+    for k in default_keys:
+        if k in samples:
+            # Strip '.value' suffix for cleaner output key names
+            user_key = k[:-6] if k.endswith('.value') else k
+            save_data[user_key] = samples[k]

     print(f"Generated samples with shapes:")
     for key, value in save_data.items():
falcon/contrib/SNPE_A.py

Lines changed: 12 additions & 13 deletions

@@ -12,7 +12,6 @@
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau

-from falcon.core.utils import RVBatch
 from falcon.core.logger import log, debug, info, warning, error
 from falcon.contrib.flow import Flow
 from falcon.contrib.stepwise_estimator import StepwiseEstimator, TrainingLoopConfig
@@ -228,10 +227,10 @@ def _unpack_batch(self, batch, phase: str):
             Tuple of (ids, theta, theta_logprob, conditions, u, u_device, conditions_device)
         """
         ids = batch._ids
-        theta = self._to_tensor(batch[self.theta_key])
-        theta_logprob = self._to_tensor(batch[f"{self.theta_key}.logprob"])
+        theta = self._to_tensor(batch[f"{self.theta_key}.value"])
+        theta_logprob = self._to_tensor(batch[f"{self.theta_key}.log_prob"])
         conditions = {
-            k: self._to_tensor(batch[k]) for k in self.condition_keys if k in batch
+            k: self._to_tensor(batch[f"{k}.value"]) for k in self.condition_keys if f"{k}.value" in batch
         }

         # Record IDs for history
@@ -345,34 +344,34 @@ def on_epoch_end(self, epoch: int, val_metrics: Dict[str, float]) -> Optional[Di

     # ==================== Sampling Methods ====================

-    def sample_prior(self, num_samples: int, conditions: Optional[Dict] = None) -> RVBatch:
+    def sample_prior(self, num_samples: int, conditions: Optional[Dict] = None) -> dict:
         """Sample from the prior distribution."""
         if conditions:
             raise ValueError("Conditions are not supported for sample_prior.")
         samples = self.simulator_instance.simulate_batch(num_samples)
         # Log probability for uniform prior over hypercube [-bound, bound]^d
         bound = self.config.inference.hypercube_bound
-        logprob = np.ones(num_samples) * (-np.log(2 * bound) ** self.param_dim)
-        return RVBatch(samples, logprob=logprob)
+        log_prob = np.ones(num_samples) * (-np.log(2 * bound) ** self.param_dim)
+        return {'value': samples, 'log_prob': log_prob}

     def sample_posterior(
         self,
         num_samples: int,
         conditions: Optional[Dict] = None,
-    ) -> RVBatch:
+    ) -> dict:
         """Sample from the posterior distribution q(theta|x)."""
         # Fall back to prior if networks not yet initialized (training hasn't started)
         if not self.networks_initialized:
             return self.sample_prior(num_samples)

         samples, logprob = self._importance_sample(num_samples, mode="posterior", conditions=conditions or {})
-        return RVBatch(samples.numpy(), logprob=logprob.numpy())
+        return {'value': samples.numpy(), 'log_prob': logprob.numpy()}

     def sample_proposal(
         self,
         num_samples: int,
         conditions: Optional[Dict] = None,
-    ) -> RVBatch:
+    ) -> dict:
         """Sample from the widened proposal distribution for adaptive resampling."""
         # Fall back to prior if networks not yet initialized (training hasn't started)
         if not self.networks_initialized:
@@ -393,7 +392,7 @@ def sample_proposal(
             "sample_proposal:std": samples.std().item(),
             "sample_proposal:logprob": logprob.mean().item(),
         })
-        return RVBatch(samples.numpy(), logprob=logprob.numpy())
+        return {'value': samples.numpy(), 'log_prob': logprob.numpy()}

     def _importance_sample(
         self,
@@ -405,8 +404,8 @@ def _importance_sample(
         cfg_inf = self.config.inference

         assert conditions, "Conditions must be provided."
-        # Move conditions to device
-        conditions = {k: v.to(self.device) for k, v in conditions.items()}
+        # Move conditions to device (handles both numpy arrays and tensors)
+        conditions = {k: self._to_tensor(v, self.device) for k, v in conditions.items()}

         # Use best models if available and configured, otherwise fall back to current
         use_best = cfg_inf.use_best_models_during_inference and self._best_conditional_flow is not None

falcon/contrib/stepwise_estimator.py

Lines changed: 21 additions & 21 deletions

@@ -15,7 +15,6 @@
 from torch.optim.lr_scheduler import ReduceLROnPlateau

 from falcon.core.base_estimator import BaseEstimator
-from falcon.core.utils import RVBatch
 from falcon.core.logger import log, debug, info, warning, error


@@ -117,7 +116,7 @@ def _to_tensor(x, device=None):
         """Convert numpy array or torch tensor to the target device."""
         if isinstance(x, torch.Tensor):
             return x if device is None else x.to(device)
-        return torch.from_numpy(x) if device is None else torch.from_numpy(x).to(device)
+        return torch.from_numpy(np.asarray(x)) if device is None else torch.from_numpy(np.asarray(x)).to(device)

     # ==================== Abstract Methods ====================

@@ -185,7 +184,8 @@ async def train(self, buffer) -> None:
             buffer: BufferView providing access to training/validation data
         """
         cfg = self.loop_config
-        keys = [self.theta_key, f"{self.theta_key}.logprob", *self.condition_keys]
+        keys = [f"{self.theta_key}.value", f"{self.theta_key}.log_prob",
+                *[f"{k}.value" for k in self.condition_keys]]
         await self._train(buffer, cfg, keys)

     async def _train(self, buffer, cfg, keys) -> None:
@@ -430,9 +430,9 @@ def _build_model(self, batch) -> nn.Module:
         from falcon.contrib.torch_embedding import instantiate_embedding

         # Extract and store tensors for reload
-        self._init_theta = self._to_tensor(batch[self.theta_key])
+        self._init_theta = self._to_tensor(batch[f"{self.theta_key}.value"])
         self._init_conditions = {
-            k: self._to_tensor(batch[k]) for k in self.condition_keys if k in batch
+            k: self._to_tensor(batch[f"{k}.value"]) for k in self.condition_keys if f"{k}.value" in batch
         }
         return self._create_model(self._init_theta, self._init_conditions)

@@ -471,11 +471,11 @@ def _create_model(self, theta: torch.Tensor, conditions: Dict[str, torch.Tensor]

     def _compute_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, float]]:
         """Compute loss from batch."""
-        theta = self._to_tensor(batch[self.theta_key], self.device)
-        theta_logprob = self._to_tensor(batch[f"{self.theta_key}.logprob"])
+        theta = self._to_tensor(batch[f"{self.theta_key}.value"], self.device)
+        theta_logprob = self._to_tensor(batch[f"{self.theta_key}.log_prob"])
         conditions = {
-            k: self._to_tensor(batch[k], self.device)
-            for k in self.condition_keys if k in batch
+            k: self._to_tensor(batch[f"{k}.value"], self.device)
+            for k in self.condition_keys if f"{k}.value" in batch
         }

         # Transform theta to latent space if mode specified
@@ -601,49 +601,49 @@ def has_trained_model(self) -> bool:

     # ==================== Sampling Methods ====================

-    def sample_prior(self, num_samples: int, conditions: Optional[Dict] = None) -> RVBatch:
+    def sample_prior(self, num_samples: int, conditions: Optional[Dict] = None) -> dict:
         """Sample from the prior distribution."""
         if conditions:
             raise ValueError("Conditions are not supported for sample_prior.")
         samples = self.simulator_instance.simulate_batch(num_samples)
-        logprob = np.zeros(num_samples)
-        return RVBatch(samples, logprob=logprob)
+        log_prob = np.zeros(num_samples)
+        return {'value': samples, 'log_prob': log_prob}

-    def _sample(self, num_samples: int, conditions: Optional[Dict], gamma: Optional[float]) -> RVBatch:
+    def _sample(self, num_samples: int, conditions: Optional[Dict], gamma: Optional[float]) -> dict:
         """Internal sampling using inference model. Falls back to prior if not trained."""
         if not self.has_trained_model:
             return self.sample_prior(num_samples)

         assert conditions, "Conditions must be provided for sampling."

         conditions_device = {
-            k: v.to(self.device).expand(num_samples, *v.shape[1:])
+            k: self._to_tensor(v, self.device).expand(num_samples, *v.shape[1:])
             for k, v in conditions.items()
         }

         model = self.inference_model
         with torch.no_grad():
             model.eval()
             samples = model.sample(conditions_device, gamma=gamma)
-            logprob = model.log_prob(samples, conditions_device)
+            log_prob = model.log_prob(samples, conditions_device)

         # Transform samples from latent space back to theta space
         if self.latent_mode is not None:
             samples = self.simulator_instance.forward(samples, mode=self.latent_mode)

-        return RVBatch(samples.cpu().numpy(), logprob=logprob.cpu().numpy())
+        return {'value': samples.cpu().numpy(), 'log_prob': log_prob.cpu().numpy()}

-    def sample_posterior(self, num_samples: int, conditions: Optional[Dict] = None) -> RVBatch:
+    def sample_posterior(self, num_samples: int, conditions: Optional[Dict] = None) -> dict:
         """Sample from the posterior distribution q(theta|x)."""
         return self._sample(num_samples, conditions, gamma=None)

-    def sample_proposal(self, num_samples: int, conditions: Optional[Dict] = None) -> RVBatch:
+    def sample_proposal(self, num_samples: int, conditions: Optional[Dict] = None) -> dict:
         """Sample from widened proposal distribution for adaptive resampling."""
         result = self._sample(num_samples, conditions, gamma=self.inference_config.gamma)
         log({
-            "sample_proposal:mean": result.value.mean(),
-            "sample_proposal:std": result.value.std(),
-            "sample_proposal:logprob": result.logprob.mean(),
+            "sample_proposal:mean": result['value'].mean(),
+            "sample_proposal:std": result['value'].std(),
+            "sample_proposal:logprob": result['log_prob'].mean(),
         })
         return result
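The np.asarray in the updated _to_tensor is deliberate: arrays coming out of _resolve_refs's broadcast path (and Ray's zero-copy returns) are non-writable, np.asarray passes them through without the copy that np.array would make, and torch.from_numpy then merely warns. A standalone demonstration of that behavior (not falcon code):

import numpy as np
import torch

base = np.zeros(3)
bcast = np.broadcast_to(base, (4, 3))  # compact zero-stride view
assert not bcast.flags.writeable

assert np.asarray(bcast) is bcast      # no-op, unlike np.array (full copy)

# Warns about the non-writable array but does not copy; a later
# .to(device) produces a writable tensor anyway.
t = torch.from_numpy(np.asarray(bcast))
print(t.shape)  # torch.Size([4, 3])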

falcon/core/base_estimator.py

Lines changed: 7 additions & 8 deletions

@@ -6,8 +6,6 @@

 import torch

-from falcon.core.utils import RVBatch
-
 # Type alias for conditions: maps node names to tensors
 Conditions = Dict[str, torch.Tensor]

@@ -20,6 +18,7 @@ class BaseEstimator(ABC):
     Concrete implementations must provide all functionality.

     Conditions are passed as Dict[str, Tensor] mapping node names to values.
+    Sampling methods return dicts with 'value' (ndarray) and optionally 'log_prob' (ndarray).
     """

     @abstractmethod
@@ -35,7 +34,7 @@ async def train(self, buffer) -> None:
     @abstractmethod
     def sample_prior(
         self, num_samples: int, conditions: Optional[Conditions] = None
-    ) -> RVBatch:
+    ) -> dict:
         """
         Sample from the prior distribution.

@@ -44,14 +43,14 @@ def sample_prior(
             conditions: Conditioning values from parent nodes (usually None for prior)

         Returns:
-            RVBatch with samples and log probabilities
+            Dict with 'value' (ndarray) and optionally 'log_prob' (ndarray)
         """
         pass

     @abstractmethod
     def sample_posterior(
         self, num_samples: int, conditions: Optional[Conditions] = None
-    ) -> RVBatch:
+    ) -> dict:
         """
         Sample from the posterior distribution.

@@ -60,14 +59,14 @@ def sample_posterior(
             conditions: Dict mapping node names to condition tensors

         Returns:
-            RVBatch with samples and log probabilities
+            Dict with 'value' (ndarray) and optionally 'log_prob' (ndarray)
         """
         pass

     @abstractmethod
     def sample_proposal(
         self, num_samples: int, conditions: Optional[Conditions] = None
-    ) -> RVBatch:
+    ) -> dict:
         """
         Sample from the proposal distribution for adaptive resampling.

@@ -76,7 +75,7 @@ def sample_proposal(
             conditions: Dict mapping node names to condition tensors

         Returns:
-            RVBatch with samples and log probabilities
+            Dict with 'value' (ndarray) and optionally 'log_prob' (ndarray)
         """
         pass
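For orientation, a toy estimator satisfying the new dict contract might look like the following (assuming the four methods shown here are the only abstract ones; FixedGaussianEstimator is illustrative, not part of falcon):

import numpy as np

from falcon.core.base_estimator import BaseEstimator


class FixedGaussianEstimator(BaseEstimator):
    """Toy estimator: standard-normal samples, conditions ignored."""

    def __init__(self, dim: int = 2):
        self.dim = dim

    async def train(self, buffer) -> None:
        pass  # nothing to train in this toy

    def _draw(self, num_samples: int) -> dict:
        value = np.random.randn(num_samples, self.dim)
        # log N(x; 0, I), summed over dimensions
        log_prob = -0.5 * (value ** 2).sum(axis=1) - 0.5 * self.dim * np.log(2 * np.pi)
        return {"value": value, "log_prob": log_prob}

    def sample_prior(self, num_samples, conditions=None) -> dict:
        return self._draw(num_samples)

    def sample_posterior(self, num_samples, conditions=None) -> dict:
        return self._draw(num_samples)

    def sample_proposal(self, num_samples, conditions=None) -> dict:
        return self._draw(num_samples)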
