Skip to content

Commit a997db0

Browse files
authored
Rollout timing metadata (#352)
1 parent c93a76c commit a997db0

File tree

2 files changed

+67
-17
lines changed

2 files changed

+67
-17
lines changed

src/ldp/alg/rollout.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import itertools
33
import logging
4+
import time
45
import traceback
56
import uuid
67
from collections import Counter
@@ -65,6 +66,21 @@ def reraise_exc_as(reraise: type[CaughtError], enabled: bool) -> Iterator[None]:
6566
raise
6667

6768

69+
class Timer:
70+
"""Tracks time spent in named operations."""
71+
72+
def __init__(self):
73+
self.info: dict[str, float] = {}
74+
75+
@contextmanager
76+
def __call__(self, name: str):
77+
start_time = time.monotonic()
78+
try:
79+
yield
80+
finally:
81+
self.info[f"time_elapsed_{name}"] = time.monotonic() - start_time
82+
83+
6884
class RolloutManager:
6985
def __init__(
7086
self,
@@ -313,6 +329,7 @@ async def _rollout(
313329
summarize_exceptions: bool = False,
314330
) -> Trajectory:
315331
trajectory = await Trajectory.from_env(env, traj_id=traj_id)
332+
timer = Timer()
316333

317334
async def store_step(step: Transition):
318335
await asyncio.gather(*[
@@ -343,7 +360,9 @@ async def store_step(step: Transition):
343360
])
344361

345362
for timestep in itertools.count():
346-
step = await self._take_step(timestep, traj_id, env, agent_state, obs)
363+
step = await self._take_step(
364+
timestep, traj_id, env, agent_state, obs, timer
365+
)
347366

348367
if timestep + 1 == max_steps and not step.done:
349368
# Mark as truncated if we hit max_steps and the state is not terminal.
@@ -376,7 +395,7 @@ async def store_step(step: Transition):
376395
next_observation=[],
377396
action=None,
378397
done=True,
379-
metadata={"exception": repr(e.original_exc)},
398+
metadata={"exception": repr(e.original_exc)} | timer.info,
380399
)
381400
)
382401

@@ -390,30 +409,47 @@ async def _take_step(
390409
env: Environment,
391410
agent_state: Any,
392411
obs: list[Message],
412+
timer: Timer | None = None,
393413
) -> Transition:
394-
async with self.concurrency_limiter:
395-
await asyncio.gather(*[
396-
callback.before_transition(traj_id, self.agent, env, agent_state, obs)
397-
for callback in self.callbacks
398-
])
414+
timer = timer or Timer()
399415

400-
with reraise_exc_as(AgentError, enabled=self.catch_agent_failures):
416+
async with self.concurrency_limiter:
417+
with timer("before_transition"):
418+
await asyncio.gather(*[
419+
callback.before_transition(
420+
traj_id, self.agent, env, agent_state, obs
421+
)
422+
for callback in self.callbacks
423+
])
424+
425+
with (
426+
timer("agent_get_asv"),
427+
reraise_exc_as(AgentError, enabled=self.catch_agent_failures),
428+
):
401429
(
402430
action,
403431
next_agent_state,
404432
value,
405433
) = await self.agent.get_asv(agent_state, obs)
406-
await asyncio.gather(*[
407-
callback.after_agent_get_asv(traj_id, action, next_agent_state, value)
408-
for callback in self.callbacks
409-
])
410434

411-
with reraise_exc_as(EnvError, enabled=self.catch_env_failures):
435+
with timer("after_agent_get_asv"):
436+
await asyncio.gather(*[
437+
callback.after_agent_get_asv(
438+
traj_id, action, next_agent_state, value
439+
)
440+
for callback in self.callbacks
441+
])
442+
443+
with (
444+
timer("env_step"),
445+
reraise_exc_as(EnvError, enabled=self.catch_env_failures),
446+
):
412447
next_obs, reward, done, trunc = await env.step(action.value)
413-
await asyncio.gather(*[
414-
callback.after_env_step(traj_id, next_obs, reward, done, trunc)
415-
for callback in self.callbacks
416-
])
448+
with timer("after_env_step"):
449+
await asyncio.gather(*[
450+
callback.after_env_step(traj_id, next_obs, reward, done, trunc)
451+
for callback in self.callbacks
452+
])
417453

418454
return Transition(
419455
timestep=timestep,
@@ -426,4 +462,5 @@ async def _take_step(
426462
next_observation=next_obs,
427463
done=done,
428464
truncated=trunc,
465+
metadata=timer.info,
429466
)

tests/test_rollouts.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,19 @@ async def test_rollout(training: bool) -> None:
9292
)
9393
assert second_traj.metadata.get("env_id") is None
9494

95+
assert all(
96+
tx.metadata.get(f"time_elapsed_{fn}") is not None
97+
for fn in (
98+
"before_transition",
99+
"agent_get_asv",
100+
"after_agent_get_asv",
101+
"env_step",
102+
"after_env_step",
103+
)
104+
for traj in trajs
105+
for tx in traj.steps
106+
), "All transitions should have timing metadata"
107+
95108
# Let's check we can serialize and deserialize the trajectories
96109
for traj in trajs:
97110
assert traj.traj_id

0 commit comments

Comments
 (0)