Ref-based object store with dict trace and three-layer architecture #24

Merged: cweniger merged 12 commits into main from feat/object_based_store on Feb 5, 2026

Conversation

@cweniger (Owner) commented Feb 4, 2026

Summary

  • Replace RVBatch with plain dict returns ({'value': ndarray, 'log_prob': ndarray}); see the sketch after this list
  • Flat store keys: theta.value, theta.log_prob instead of bare theta / theta.logprob
  • Ref-only outer layer: DeployedGraph and store never touch arrays, only ObjectRefs
  • NodeWrapper handles ref↔array boundary with batched ray.get / ray.put
  • Remove RV, RVBatch, as_rv, as_rvbatch from utils.py
  • Remove unnecessary pause/resume during resimulation (async actors interleave naturally)
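
A minimal sketch of the new return contract and the flat store keys (the node name theta matches the bullets above; sample_theta and the literal store dict are illustrative stand-ins, not the actual API):

```python
import numpy as np

# A node's sampling call now returns a plain dict instead of an RVBatch:
#   {'value': ndarray, 'log_prob': ndarray}
def sample_theta(batch_size: int) -> dict:
    value = np.random.normal(size=(batch_size, 2))   # sampled parameters
    log_prob = -0.5 * (value ** 2).sum(axis=1)       # per-sample log density (unnormalized, illustrative)
    return {"value": value, "log_prob": log_prob}

# The store is keyed flat, one array per key:
out = sample_theta(4)
store = {
    "theta.value": out["value"],        # was: bare "theta"
    "theta.log_prob": out["log_prob"],  # was: "theta.logprob"
}
```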

Performance optimizations

Broadcast detection and deferred expansion

Profiling on 05_linear_regression (n=1024 resample batch) identified bottlenecks in sample_proposal that block the training event loop:

| Bottleneck | Before | After | Improvement |
| --- | --- | --- | --- |
| _resolve_refs (broadcast observation resolved N times) | 280ms | ~1ms | ~280x: detect identical refs, resolve once, return (1, ...) shape |
| _batch_to_refs (per-sample ray.put) | 140ms (sequential) | ~100ms | ThreadPoolExecutor(4) parallelizes puts |
| Total actor blocking per resim | ~430ms | ~100ms | ~4x less training interruption |

Broadcast conditions now stay as shape (1, ...) through _resolve_refs_chunked_sample, expanded only at the final consumer (_simulate via np.broadcast_to, or embeddings via .expand()).
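
A minimal sketch of the broadcast shortcut, assuming refs is a list of ObjectRefs for one condition key (resolve_refs and simulate are simplified stand-ins for _resolve_refs and _simulate; the real methods are chunked and batched):

```python
import numpy as np
import ray

def resolve_refs(refs):
    """Resolve a list of ObjectRefs to an array of conditions.

    If every ref is identical (a broadcast case such as a repeated observation),
    fetch it once and return a compact (1, ...) array; otherwise do one batched ray.get.
    """
    if len(set(refs)) == 1:
        return np.asarray(ray.get(refs[0]))[None, ...]   # shape (1, ...), no N-fold copy
    return np.stack(ray.get(refs))                       # shape (N, ...)

def simulate(cond, n):
    """Expand the compact condition only at the final consumer."""
    return np.broadcast_to(cond, (n,) + cond.shape[1:])  # read-only view; expanded lazily
```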

Background data fetch with tensor pre-building

CachedDataLoader.sync() was blocking training for ~4-5s per epoch due to:

  1. checkout_refs + ray.get on the main thread (moved to background in prior commit)
  2. _apply_fetch doing per-sample np.array → torch.as_tensor → .to(device) for 1024 samples × N keys

Fix: The background thread now does the full pipeline — checkout_refs, ray.get, np.stack, and torch.as_tensor().to(device) — returning ready-to-use tensors. _apply_fetch on the main thread only does fast batched indexed scatter and torch.cat.
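
A sketch of the background half of this pipeline (_checkout_and_fetch follows the description above; start_background_fetch, refs_by_key, and the module-level pool are illustrative assumptions, not the actual CachedDataLoader internals):

```python
import concurrent.futures
import numpy as np
import ray
import torch

_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def _checkout_and_fetch(refs_by_key, device):
    """Run the full fetch pipeline off the training thread: batched ray.get,
    np.stack, and tensor construction on the target device."""
    return {
        key: torch.as_tensor(np.stack(ray.get(refs))).to(device)
        for key, refs in refs_by_key.items()
    }

def start_background_fetch(refs_by_key, device="cpu"):
    # Training keeps running; the pool thread returns ready-to-use tensors.
    return _pool.submit(_checkout_and_fetch, refs_by_key, device)
```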

| Metric | Before | After | Improvement |
| --- | --- | --- | --- |
| _apply_fetch (1024 samples) | 4.0s | 0.2s | 20x |
| _apply_fetch (256 val samples) | 1.2s | 0.1s | 12x |
| sync() total | 5.3s (blocking) | 0.3s (non-blocking) | 18x |
| Epoch time | 14.2s | ~10s | ~30% faster |

sync() is now truly non-blocking: checks .done() before applying, skips if the background fetch is still running.
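
And the non-blocking check on the training side, a sketch where apply_fetch stands in for the fast main-thread _apply_fetch and fetch_future is the future returned by the background submit above:

```python
def sync(fetch_future, apply_fetch):
    """Apply background-fetch results only when they are ready; never block training."""
    if fetch_future is None or not fetch_future.done():
        return False                     # fetch still running: skip this epoch
    apply_fetch(fetch_future.result())   # fast batched indexed scatter + torch.cat
    return True
```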

Test results

All smoke tests pass (pytest tests/test_examples_smoke.py):

  • 01_minimal: 2 epochs
  • 02_bimodal: 2 epochs
  • 03_composite: 2 epochs
  • 04_gaussian: 2 epochs
  • 05_linear_regression: 2 epochs

Total pytest runtime: ~3 minutes (CPU mode)

Test plan

  • Run pytest tests/test_examples_smoke.py — all 5 passed
  • Profile GPU utilization on 05_linear_regression — confirmed 30-70% steady with brief <400ms dips

cweniger and others added 12 commits February 3, 2026 22:25
Convert data passing to use Ray ObjectRefs throughout, eliminating
unnecessary serialization through the driver. Key changes:

- NodeWrapper.sample() now returns List[Dict[str, ObjectRef]] with
  internal chunking based on ray.chunk_size config
- Add _simulate(), _sample_posterior(), _sample_proposal() as internal
  methods for actual sampling logic
- Add _convert_rvbatch_to_refs() helper for RVBatch to refs conversion
- MultiplexNodeWrapper updated for ref-based multiplexed sampling
- DeployedGraph gains _resolve_refs_to_trace(), _merge_refs(),
  _refs_to_arrays() helpers for graph traversal with refs
- DatasetManagerActor.append_refs() stores ObjectRefs directly
- Add _dump_refs() for lazy disk dumps when store_fraction > 0
- Fix resampling bug by merging proposal_refs with sample_refs

This enables zero-copy data flow between simulation and buffer actors,
supports heterogeneous sample shapes, and makes the buffer shape-agnostic.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
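
A hedged sketch of the chunked ref return described above (batch_to_chunked_refs is a hypothetical stand-in for the internal helper, and chunk_size stands in for the ray.chunk_size config value):

```python
import ray

def batch_to_chunked_refs(value, log_prob, chunk_size=256):
    """Split a sampled batch into chunks and put each chunk into the object store,
    returning List[Dict[str, ObjectRef]]. Assumes ray.init() has been called."""
    refs = []
    for start in range(0, len(value), chunk_size):
        stop = start + chunk_size
        refs.append({
            "value": ray.put(value[start:stop]),
            "log_prob": ray.put(log_prob[start:stop]),
        })
    return refs
```
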
Replace RVBatch wrapper handling with explicit tuple returns:
- _simulate() now returns (value, logprob) tuple
- _sample_posterior() returns (value, logprob) tuple
- _sample_proposal() returns (value, logprob) tuple
- Rename _convert_rvbatch_to_refs() to _batch_to_refs(value, logprob)

This removes hasattr checks and as_rvbatch() calls, making the
internal interface explicit. RVBatch is still used by estimators
internally, but the boundary is now a simple tuple.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Three-layer architecture: outer (DeployedGraph/store) passes only
ObjectRefs, NodeWrapper handles ref↔array boundary with batched
ray.get/put, inner (estimators/simulators) sees pure arrays/tensors.

- Remove RVBatch/RV classes entirely
- Estimators return {'value': arr, 'log_prob': arr} dicts
- Flat store keys: theta.value, theta.log_prob (not bare theta)
- NodeWrapper: unified _chunked_sample, _resolve_refs (batched
  ray.get), _batch_to_refs (phased puts), condition_refs interface
- MultiplexNodeWrapper: factored _multiplexed_call, slices ref lists
- DeployedGraph: ref-only _execute_graph (no double serialization),
  _arrays_to_condition_refs, _extract_value_refs, batched _refs_to_arrays
- Fix cli.py sample_mode for new ref-based return format
- Fix read-only array warning from Ray zero-copy returns

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
sample_proposal interleaves with training via async yield points
(await asyncio.sleep(0) in the training loop). Forward simulation
runs on separate actors concurrently with training. No need to
pause training during the entire resimulation cycle.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
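
A toy illustration of the yield-point pattern this relies on (train_step and batches are placeholders, not the project's API):

```python
import asyncio

async def train_epoch(batches, train_step):
    for batch in batches:
        train_step(batch)         # one synchronous optimizer step
        await asyncio.sleep(0)    # yield point: lets queued sample_proposal calls interleave
```
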
_resolve_refs now detects when all refs are identical (broadcast case,
e.g. repeated observations) and resolves once with np.broadcast_to
instead of fetching N copies. Reduces resolve_refs from ~280ms to ~1ms
for broadcast observations.

_batch_to_refs uses ThreadPoolExecutor(4) for ray.put calls, giving
~1.3x speedup for small arrays and ~4x for large arrays (GIL released
during serialization).

_to_tensor handles read-only numpy arrays (from np.broadcast_to) by
copying before torch.from_numpy conversion.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
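
A sketch of the parallel puts (batch_to_refs is illustrative; the real _batch_to_refs also phases its puts):

```python
from concurrent.futures import ThreadPoolExecutor

import ray

_put_pool = ThreadPoolExecutor(max_workers=4)

def batch_to_refs(values, log_probs):
    """Per-sample ray.put in parallel; the GIL is released during serialization,
    so a small thread pool overlaps the puts."""
    value_refs = list(_put_pool.map(ray.put, values))
    log_prob_refs = list(_put_pool.map(ray.put, log_probs))
    return value_refs, log_prob_refs
```
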
Use np.asarray (no-op) instead of np.array (100MB copy) for read-only
broadcast arrays from _resolve_refs. The PyTorch non-writable warning
is harmless since .to(device) always creates a writable copy.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
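
A small illustration of the difference (shapes are arbitrary):

```python
import numpy as np
import torch

cond = np.broadcast_to(np.zeros((1, 256)), (1024, 256))  # read-only view, no copy
arr = np.asarray(cond)    # no-op; np.array(cond) would materialize the full expanded copy
t = torch.as_tensor(arr)  # PyTorch warns that the array is non-writable; harmless here
# A later device/dtype transfer (e.g. .to("cuda")) produces a writable copy anyway.
```
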
Lazy broadcast: _resolve_refs now returns compact (1,...) arrays for
broadcast conditions, deferring np.broadcast_to to the point of use in
_simulate. Reduces memory allocation during chunked sampling.

Background fetch: CachedDataLoader.sync() offloads ray.get to a
background thread so the training loop is not blocked. New data is
applied one epoch later; first sync blocks to ensure initial data.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
_apply_fetch was taking ~4s per sync (1024 samples) because it ran
np.stack + _to_tensor per-sample on the main thread, blocking GPU
training. Now _checkout_and_fetch does the full pipeline in the
background thread (checkout_refs, ray.get, np.stack, torch.as_tensor),
returning ready-to-use tensors. _apply_fetch on the main thread only
does fast batched scatter/cat (~0.2s). Also makes sync() truly
non-blocking by checking .done() before applying.

Measured: _apply_fetch 4.0s → 0.2s, epoch time 14.2s → 10s.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Background thread's torch.as_tensor may produce different precision
(e.g., Float vs Double) than the existing cache arrays. Cast new
tensors to match the destination dtype before indexed assignment.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
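
A minimal reproduction of the cast (shapes and dtypes are arbitrary):

```python
import torch

cache = torch.zeros(1024, 8, dtype=torch.float64)  # existing cache tensor (Double)
new = torch.randn(256, 8)                          # tensor built in the background thread (Float)
idx = torch.randperm(1024)[:256]

# Cast to the destination dtype first; otherwise the indexed assignment can fail
# with a source/destination dtype mismatch.
cache[idx] = new.to(cache.dtype)
```
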
Cleanup:
- Remove unused call_simulator_method, call_estimator_method (NodeWrapper)
- Remove duplicate shutdown() stub in NodeWrapper
- Remove unused shutdown() in DatasetManagerActor
- Remove unused load_observations() and its imports from utils.py
- Remove leftover debug log in _merge_refs
- Simplify _execute_graph: deduplicate parent loop in conditional

Bugfix:
- Add missing joblib import in raystore.py (used by load_initial_samples)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
buffer.get_stats() does a synchronous ray.get to fetch buffer statistics.
This triggers a Ray warning about blocking in async actors. The call is
fast (<1ms) and only happens once per epoch, so the blocking is negligible.

Added warnings.catch_warnings context manager to suppress this specific
warning at the two call sites in stepwise_estimator.py, with TODO noting
potential future improvements (async get_stats or background caching).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Set Ray's internal blocking_get_inside_async_warned flag to True
in NodeWrapper.__init__ to prevent the warning from being logged.
This warning is unavoidable without splitting async training actors
from synchronous sampling actors.

Also removes ineffective warnings.filterwarnings context managers
from stepwise_estimator.py that were previously added but didn't
work (Ray uses logging.warning, not warnings.warn).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@cweniger (Owner, Author) commented Feb 5, 2026

Ray Warning Suppression Fix

The "Using blocking ray.get inside async actor" warning has been successfully suppressed.

Root cause: Ray emits this warning via logging.warning(), not warnings.warn(). Previous attempts using warnings.filterwarnings() did not work.

Solution: Set Ray's internal blocking_get_inside_async_warned flag to True in NodeWrapper.__init__ before any ray.get calls. This prevents the warning from being logged while preserving the underlying behavior.

The fix is minimal and isolated, with a TODO note about potentially splitting async training actors from synchronous sampling actors in the future.
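
For reference, a hedged sketch of the suppression (the flag name comes from the comment above; the module path ray._private.worker is an assumption about recent Ray releases and is private API that may move between versions):

```python
def _silence_blocking_get_warning():
    """Pre-set Ray's internal flag so the one-time warning is never logged."""
    try:
        import ray._private.worker as ray_worker
        ray_worker.blocking_get_inside_async_warned = True
    except (ImportError, AttributeError):
        pass  # different Ray version: tolerate the one-time warning instead
```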

@cweniger (Owner, Author) left a comment

Looks all good!
