11import asyncio
22import itertools
33import logging
4+ import time
45import traceback
56import uuid
67from collections import Counter
@@ -65,6 +66,21 @@ def reraise_exc_as(reraise: type[CaughtError], enabled: bool) -> Iterator[None]:
6566 raise
6667
6768
class Timer:
    """Track how long named operations take.

    Calling an instance with a name returns a context manager. When the
    managed block exits, the time elapsed since entry (measured with
    ``time.monotonic()``) is stored in ``info`` under the key
    ``time_elapsed_<name>``. Timing the same name again overwrites the
    previously stored value.
    """

    def __init__(self) -> None:
        # Maps "time_elapsed_<name>" to the most recently measured duration.
        self.info: dict[str, float] = {}

    @contextmanager
    def __call__(self, name: str) -> Iterator[None]:
        """Time the enclosed ``with`` block and record it under ``name``.

        Args:
            name: Label for the operation; the measurement is stored in
                ``info`` as ``time_elapsed_<name>``.

        Yields:
            None. The context manager exists only for its timing side effect.
        """
        start_time = time.monotonic()
        try:
            yield
        finally:
            # Recorded in ``finally`` so the duration is captured even when
            # the timed block raises; the exception still propagates.
            self.info[f"time_elapsed_{name}"] = time.monotonic() - start_time
82+
83+
6884class RolloutManager :
6985 def __init__ (
7086 self ,
@@ -313,6 +329,7 @@ async def _rollout(
313329 summarize_exceptions : bool = False ,
314330 ) -> Trajectory :
315331 trajectory = await Trajectory .from_env (env , traj_id = traj_id )
332+ timer = Timer ()
316333
317334 async def store_step (step : Transition ):
318335 await asyncio .gather (* [
@@ -343,7 +360,9 @@ async def store_step(step: Transition):
343360 ])
344361
345362 for timestep in itertools .count ():
346- step = await self ._take_step (timestep , traj_id , env , agent_state , obs )
363+ step = await self ._take_step (
364+ timestep , traj_id , env , agent_state , obs , timer
365+ )
347366
348367 if timestep + 1 == max_steps and not step .done :
349368 # Mark as truncated if we hit max_steps and the state is not terminal.
@@ -376,7 +395,7 @@ async def store_step(step: Transition):
376395 next_observation = [],
377396 action = None ,
378397 done = True ,
379- metadata = {"exception" : repr (e .original_exc )},
398+ metadata = {"exception" : repr (e .original_exc )} | timer . info ,
380399 )
381400 )
382401
@@ -390,30 +409,47 @@ async def _take_step(
390409 env : Environment ,
391410 agent_state : Any ,
392411 obs : list [Message ],
412+ timer : Timer | None = None ,
393413 ) -> Transition :
394- async with self .concurrency_limiter :
395- await asyncio .gather (* [
396- callback .before_transition (traj_id , self .agent , env , agent_state , obs )
397- for callback in self .callbacks
398- ])
414+ timer = timer or Timer ()
399415
400- with reraise_exc_as (AgentError , enabled = self .catch_agent_failures ):
416+ async with self .concurrency_limiter :
417+ with timer ("before_transition" ):
418+ await asyncio .gather (* [
419+ callback .before_transition (
420+ traj_id , self .agent , env , agent_state , obs
421+ )
422+ for callback in self .callbacks
423+ ])
424+
425+ with (
426+ timer ("agent_get_asv" ),
427+ reraise_exc_as (AgentError , enabled = self .catch_agent_failures ),
428+ ):
401429 (
402430 action ,
403431 next_agent_state ,
404432 value ,
405433 ) = await self .agent .get_asv (agent_state , obs )
406- await asyncio .gather (* [
407- callback .after_agent_get_asv (traj_id , action , next_agent_state , value )
408- for callback in self .callbacks
409- ])
410434
411- with reraise_exc_as (EnvError , enabled = self .catch_env_failures ):
435+ with timer ("after_agent_get_asv" ):
436+ await asyncio .gather (* [
437+ callback .after_agent_get_asv (
438+ traj_id , action , next_agent_state , value
439+ )
440+ for callback in self .callbacks
441+ ])
442+
443+ with (
444+ timer ("env_step" ),
445+ reraise_exc_as (EnvError , enabled = self .catch_env_failures ),
446+ ):
412447 next_obs , reward , done , trunc = await env .step (action .value )
413- await asyncio .gather (* [
414- callback .after_env_step (traj_id , next_obs , reward , done , trunc )
415- for callback in self .callbacks
416- ])
448+ with timer ("after_env_step" ):
449+ await asyncio .gather (* [
450+ callback .after_env_step (traj_id , next_obs , reward , done , trunc )
451+ for callback in self .callbacks
452+ ])
417453
418454 return Transition (
419455 timestep = timestep ,
@@ -426,4 +462,5 @@ async def _take_step(
426462 next_observation = next_obs ,
427463 done = done ,
428464 truncated = trunc ,
465+ metadata = timer .info ,
429466 )
0 commit comments