Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,5 @@ cython_debug/

# Version files made by setuptools_scm
**/version.py

.vscode/
159 changes: 127 additions & 32 deletions src/ldp/alg/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import itertools
import logging
import uuid
from collections import Counter
from collections.abc import Callable, Iterator, Sequence
from contextlib import contextmanager, nullcontext
from typing import Any, TypeVar, overload

from aviary.core import Environment, Message
from tqdm.asyncio import tqdm

from ldp.agent import Agent
from ldp.data_structures import Trajectory, Transition
Expand All @@ -24,6 +26,7 @@ class CaughtError(Exception):
"""Base class for reraised exceptions when catching is enabled."""

def __init__(self, original_exc: Exception):
super().__init__(str(original_exc))
self.original_exc = original_exc

exc_type = "undefined"
Expand All @@ -39,12 +42,12 @@ class EnvError(CaughtError):

@contextmanager
def reraise_exc_as(reraise: type[CaughtError], enabled: bool) -> Iterator[None]:
"""Context manager that reraises exceptions as a custom CaughtError type if enabled."""
try:
yield
except Exception as e:
if enabled:
error_details = format_error_details(e)
logger.exception(f"Caught {reraise.exc_type} exception:\n{error_details}")
logger.debug(f"Reraising {reraise.exc_type} exception.")
raise reraise(e) from None
raise

Expand Down Expand Up @@ -106,6 +109,9 @@ async def sample_trajectories( # noqa: D418
environments: A list of environments to run rollouts on.
max_steps: Max steps per rollout. Defaults to None, in which case the rollouts are run
until environment returns done.
log_exceptions_immediately: Whether to log exceptions in the rollout immediately
to the console. Defaults to True. If False, progress bar will show and a summary
will be logged after all rollouts are complete.
"""

async def sample_trajectories(self, **kwargs):
Expand All @@ -118,14 +124,21 @@ async def sample_trajectories(self, **kwargs):
kwargs["environment_factory"],
kwargs.get("batch_size", 1),
kwargs.get("max_steps"),
log_exceptions_immediately=kwargs.get(
"log_exceptions_immediately", True
),
)

if "environments" in kwargs:
assert "environment_factory" not in kwargs, (
"Cannot use environments with environment_factory"
)
return await self._sample_trajectories_from_envs(
kwargs["environments"], kwargs.get("max_steps")
kwargs["environments"],
kwargs.get("max_steps"),
log_exceptions_immediately=kwargs.get(
"log_exceptions_immediately", True
),
)

raise TypeError(
Expand All @@ -138,13 +151,18 @@ async def _sample_trajectories_from_env_factory(
environment_factory: Callable[[], Environment],
batch_size: int = 1,
max_steps: int | None = None,
*,
log_exceptions_immediately: bool = True,
) -> list[tuple[Trajectory, Environment]]:
self.traj_buffer.clear()
exception_counter: Counter = Counter()

async def rollout_with_args(idx: int, **rollout_kwargs):
return idx, await self._rollout(**rollout_kwargs), rollout_kwargs

accumulated_steps = [0] * batch_size
total_trajectories = 0 # Counter for completed trajectories

# submit initial batch of tasks
tasks = [
asyncio.create_task(
Expand All @@ -153,61 +171,134 @@ async def rollout_with_args(idx: int, **rollout_kwargs):
traj_id=uuid.uuid4().hex,
env=environment_factory(),
max_steps=max_steps,
log_exceptions_immediately=log_exceptions_immediately,
)
)
for idx in range(batch_size)
]

results = []
while tasks:
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
new_tasks = []
for task in done:
idx, traj, kwargs = await task
results.append((traj, kwargs["env"]))
accumulated_steps[idx] += len(traj.steps)
if (
max_steps is not None
and (remaining_steps := max_steps - accumulated_steps[idx]) > 0
):
# submit another task if we haven't reached max_steps
new_task = asyncio.create_task(
rollout_with_args(
idx,
traj_id=uuid.uuid4().hex,
env=environment_factory(),
max_steps=remaining_steps,
with tqdm(
desc="Rollouts",
unit="rollout",
ncols=0,
disable=log_exceptions_immediately,
) as pbar:
while tasks:
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
new_tasks = []
for task in done:
idx, traj, kwargs = await task
results.append((traj, kwargs["env"]))
total_trajectories += 1
pbar.update(1)

steps_in_traj = len(traj.steps)
accumulated_steps[idx] += steps_in_traj

# Check for exceptions in this trajectory
if traj.steps and traj.steps[-1].metadata.get("exception"):
exc_str: str = str(traj.steps[-1].metadata["exception"])[
:500
].replace('"', "'")
exception_counter[exc_str] += 1
num_exceptions = sum(exception_counter.values())
pbar.set_postfix({"num_exceptions": num_exceptions})

if (
max_steps is not None
and (remaining_steps := max_steps - accumulated_steps[idx]) > 0
):
# submit another task if we haven't reached max_steps
new_task = asyncio.create_task(
rollout_with_args(
idx,
traj_id=uuid.uuid4().hex,
env=environment_factory(),
max_steps=remaining_steps,
log_exceptions_immediately=log_exceptions_immediately,
)
)
)
new_tasks.append(new_task)
new_tasks.append(new_task)

tasks = list(pending) + new_tasks
tasks = list(pending) + new_tasks

# Final summary of exceptions (if any)
if exception_counter and not log_exceptions_immediately:
summary = ["Caught exceptions:", "Count Exception"]
summary.extend(
f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items()
)
logger.info("\n".join(summary))

return results

async def _sample_trajectories_from_envs(
self,
environments: Sequence[Environment],
max_steps: int | None = None,
*,
log_exceptions_immediately: bool = True,
) -> list[Trajectory]:
self.traj_buffer.clear()
exception_counter: Counter = Counter()

traj_ids = [uuid.uuid4().hex for _ in environments]

traj_ids = [uuid.uuid4().hex for _ in range(len(environments))]
await asyncio.gather(
*(
self._rollout(*args, max_steps=max_steps)
for args in zip(traj_ids, environments, strict=True)
# Create all tasks first
tasks = [
asyncio.create_task(
self._rollout(
traj_id,
env,
max_steps=max_steps,
log_exceptions_immediately=log_exceptions_immediately,
)
)
)
for traj_id, env in zip(traj_ids, environments, strict=True)
]

with tqdm(
total=len(tasks),
desc="Rollouts",
unit="rollout",
ncols=0,
disable=log_exceptions_immediately,
) as pbar:
for task in asyncio.as_completed(tasks):
trajectory = await task
pbar.update(1)
# Check if this trajectory ended with an exception
if trajectory.steps:
last_step = trajectory.steps[-1]
if last_step.metadata.get("exception"):
# We'll keep it short but still have something to categorize
exc_str: str = str(last_step.metadata["exception"])[
:500
].replace('"', "'")
exception_counter[exc_str] += 1
num_exceptions = sum(exception_counter.values())
pbar.set_postfix({"num_exceptions": num_exceptions})

# Final summary of exceptions (if any)
if exception_counter and not log_exceptions_immediately:
summary = ["Caught exceptions:", "Count Exception"]
summary.extend(
f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items()
)
logger.info("\n".join(summary))

return [self.traj_buffer[traj_id] for traj_id in traj_ids]

async def _rollout(
self,
traj_id: str,
env: Environment,
max_steps: int | None,
*,
log_exceptions_immediately: bool = True,
) -> Trajectory:
trajectory = Trajectory(traj_id=traj_id)

Expand Down Expand Up @@ -260,6 +351,10 @@ async def store_step(step: Transition):
except CaughtError as e:
# NOTE: This trajectory should not be used for regular training.
# We save the last transition here for debugging, etc.
if log_exceptions_immediately:
error_details = format_error_details(e.original_exc)
logger.exception(f"Exception in rollout {traj_id}:\n{error_details}")

await store_step(
Transition(
timestep=len(trajectory.steps),
Expand Down
16 changes: 14 additions & 2 deletions src/ldp/graph/async_torch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = ["AsyncTorchModule", "async_protect_torch_call"]

import asyncio
import logging
import operator
import time
from abc import ABC, abstractmethod
Expand All @@ -19,6 +20,9 @@
"Please run `pip install ldp[nn]`."
) from None


logger = logging.getLogger(__name__)

_TORCH_LOCK = asyncio.Lock()

# Supported devices here: https://pytorch.org/docs/stable/amp.html#torch.autocast
Expand Down Expand Up @@ -90,6 +94,7 @@ def __init__(
self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
self._result_buffer: dict[UUID, Any] = {}
self._lock = asyncio.Lock()
self._exception_raised: Exception | None = None

async def __call__(self, **kwargs):
request_id = uuid4()
Expand All @@ -101,16 +106,23 @@ async def __call__(self, **kwargs):

while True:
async with self._lock:
if self._exception_raised is not None:
logger.info("Exception raised in another coroutine")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this warning or error level?

raise self._exception_raised

# Only one coroutine allowed in here when:
# - modifying the result buffer
# - modifying the work buffer

if request_id in self._result_buffer:
# Our request was fulfilled by this or another coroutine!
return self._result_buffer.pop(request_id)

# Try to run a batch.
await self._maybe_process_batch()
try:
await self._maybe_process_batch()
except Exception as e:
self._exception_raised = e
raise

# Sleep, to let another coroutine take over if it needs to
await asyncio.sleep(0.0)
Expand Down
32 changes: 32 additions & 0 deletions src/ldp/nn/agent/simple_local_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
from typing import cast

import torch
import torch.distributed as dist
from aviary.core import Message, Tool, ToolRequestMessage
from litellm.utils import token_counter
from pydantic import Field, field_validator

from ldp.agent import Agent, SimpleAgentState
Expand All @@ -17,6 +19,8 @@
)
from ldp.nn.lm_config import LMConfig as _LMConfig

logger = logging.getLogger(__name__)


class AgentLMConfig(_LMConfig):
"""Adds some additional configuration options for running an LM in an Op."""
Expand All @@ -42,6 +46,10 @@ class AgentLMConfig(_LMConfig):
),
validate_default=True,
)
max_messages_token_count: int | None = Field(
default=None,
description="If set, raise an error if the total tokens in the trajectory exceed this value.",
)

@field_validator("llm_call_kwargs")
@classmethod
Expand Down Expand Up @@ -91,6 +99,8 @@ async def get_asv(
else next_state.messages
)

self._validate_token_count(messages, next_state.tools)

# Execute the LLM operation call
result = cast(
"OpResult[Message | ToolRequestMessage]",
Expand All @@ -112,8 +122,30 @@ async def get_asv(

# Update state messages with result and return the new state
next_state.messages = [*next_state.messages, result.value]
self._validate_token_count(next_state.messages, next_state.tools)

return cast("OpResult[ToolRequestMessage]", result), next_state, 0.0

def _validate_token_count(self, messages: list[Message], tools: list[Tool]):
"""Asserts token count for the trajectory is within the limit."""
if self.llm_model.max_messages_token_count is None:
return
messages_for_tokenizer = self._llm_call_op.prep_messages_for_tokenizer(messages)
tools_for_tokenizer = self._llm_call_op.prep_tools_for_tokenizer(tools)

total_tokens = token_counter(
model=self.llm_model.model,
messages=messages_for_tokenizer,
tools=tools_for_tokenizer, # type: ignore[arg-type]
)
if total_tokens > self.llm_model.max_messages_token_count:
logger.error(
f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_messages_token_count}"
)
raise ValueError(
f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_messages_token_count}"
)

# TODO: maybe remove these recomputation methods. I added them to debug some things. But idk,
# maybe they'll come in handy later.
@staticmethod
Expand Down
Loading
Loading