nits
Ori Kabeli committed Feb 18, 2025
commit 13c962224ad7147310b718d912503aebd9ea4290
2 changes: 1 addition & 1 deletion .gitignore
@@ -298,4 +298,4 @@ cython_debug/
# Version files made by setuptools_scm
**/version.py

.vscode/
.vscode/
18 changes: 8 additions & 10 deletions ldp/alg/rollout.py
@@ -8,10 +8,10 @@
from typing import Any, TypeVar, overload

from aviary.core import Environment, Message
from tqdm import tqdm

from ldp.agent import Agent
from ldp.data_structures import Trajectory, Transition
from ldp.utils import format_error_details

from .callbacks import Callback

@@ -196,20 +196,16 @@ async def _sample_trajectories_from_envs(
max_steps: int | None = None,
) -> list[Trajectory]:
self.traj_buffer.clear()
exception_counter = Counter()
exception_counter: Counter = Counter()

traj_ids = [uuid.uuid4().hex for _ in environments]

# Create all tasks first
tasks = [
asyncio.create_task(
self._rollout(traj_id, env, max_steps=max_steps)
)
asyncio.create_task(self._rollout(traj_id, env, max_steps=max_steps))
for traj_id, env in zip(traj_ids, environments, strict=True)
]

# Use a single line bar_format to avoid multiline spam.
from tqdm import tqdm
bar_format = (
"{l_bar}{bar} {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]"
" {postfix}"
@@ -229,17 +225,19 @@ async def _sample_trajectories_from_envs(
last_step = trajectory.steps[-1]
if last_step.metadata.get("exception"):
# We'll keep it short but still have something to categorize
exc_str = last_step.metadata["exception"][:500].replace('"', "'")
exc_str: str = last_step.metadata["exception"][:500]
exc_str = exc_str.replace('"', "'")
exception_counter[exc_str] += 1
num_exceptions = sum(exception_counter.values())
pbar.set_postfix({"num_exceptions": num_exceptions})

# Final summary of exceptions (if any)
if exception_counter:
logger.info("Caught exceptions:")
logger.info("{:<6} {:<50}".format("Count", "Exception"))
logger.info("%-6s %-50s", "Count", "Exception")
for exc, count in exception_counter.items():
logger.info("{:<6} {:<50}".format(count, exc))
logger.info("%-6d %-50s", count, exc)

return [self.traj_buffer[traj_id] for traj_id in traj_ids]

async def _rollout(
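For context on the progress-reporting changes above, here is a minimal, self-contained sketch of the same pattern: a single-line tqdm bar whose postfix tracks an exception `Counter`, followed by a %-style (lazily formatted) logging summary. `run_all` and the try/except around each task are illustrative assumptions, not code from this PR.

```python
import asyncio
import logging
from collections import Counter

from tqdm import tqdm

logger = logging.getLogger(__name__)


async def run_all(tasks: list[asyncio.Task]) -> None:
    exception_counter: Counter = Counter()
    # Single-line bar_format keeps tqdm from printing a new multiline bar per update.
    bar_format = (
        "{l_bar}{bar} {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}] {postfix}"
    )
    with tqdm(total=len(tasks), bar_format=bar_format) as pbar:
        for coro in asyncio.as_completed(tasks):
            try:
                await coro
            except Exception as exc:  # count failures instead of aborting the batch
                exception_counter[str(exc)[:500].replace('"', "'")] += 1
                pbar.set_postfix({"num_exceptions": sum(exception_counter.values())})
            pbar.update(1)

    if exception_counter:
        # %-style arguments defer string formatting until the record is emitted.
        logger.info("%-6s %-50s", "Count", "Exception")
        for exc, count in exception_counter.items():
            logger.info("%-6d %-50s", count, exc)
```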
12 changes: 9 additions & 3 deletions ldp/graph/async_torch.py
@@ -120,14 +120,20 @@ async def _maybe_process_batch(self) -> None:

If neither condition is met, do nothing.
"""
# Technically this should not happen, but if a coroutine crashes, it could release
# self._lock before placing results in _results_buffer, and another coroutine
# entering afterwards would then crash.
if not self._work_buffer:
return

now = time.time()

# sort by oldest requests first
self._work_buffer.sort(key=operator.itemgetter(0))

if (
len(self._work_buffer) >= self.batch_size
or (now - self._work_buffer[0][0] > self.timeout) and len(self._work_buffer) > 0
if len(self._work_buffer) >= self.batch_size or (
(now - self._work_buffer[0][0] > self.timeout)
and len(self._work_buffer) > 0
):
# if we're over batch size or have at least one input waiting for
# more than timeout, pull out a batch to run
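A minimal sketch (not this module's actual class) of the buffering rule the reformatted condition implements: requests are stored as `(timestamp, payload)` pairs, and a batch is pulled once the buffer reaches `batch_size` or the oldest entry has waited longer than `timeout` seconds. The empty-buffer guard mirrors the early return added above.

```python
import operator
import time


class TimedBatcher:
    """Accumulate (timestamp, payload) pairs; flush on size or age."""

    def __init__(self, batch_size: int, timeout: float):
        self.batch_size = batch_size
        self.timeout = timeout
        self._work_buffer: list[tuple[float, object]] = []

    def submit(self, payload: object) -> None:
        self._work_buffer.append((time.time(), payload))

    def maybe_pull_batch(self) -> list[object] | None:
        # Guard: a crashed producer may have released its lock without enqueueing.
        if not self._work_buffer:
            return None
        # Oldest request first, so the age check only needs to look at index 0.
        self._work_buffer.sort(key=operator.itemgetter(0))
        now = time.time()
        if (
            len(self._work_buffer) >= self.batch_size
            or now - self._work_buffer[0][0] > self.timeout
        ):
            batch = self._work_buffer[: self.batch_size]
            del self._work_buffer[: self.batch_size]
            return [payload for _, payload in batch]
        return None
```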
2 changes: 1 addition & 1 deletion ldp/nn/handlers/chunking.py
@@ -159,7 +159,7 @@ def _split_value(self, value):
for i in range(self.num_chunks):
if i >= len(chunks):
# Chunk 0 will always exist, and we need only a batch of one ([:1])
# to activate the model.
# to activate the model.
# We use real data to avoid errors in the model expecting certain token structure.
chunks.append(chunks[0][:1])
dummy_chunk_flags.append(True)
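The comment above describes padding a short list of chunks up to `num_chunks` with a one-row slice of chunk 0 so that every rank still runs the model on real token structure. A hedged, standalone sketch of that idea (the function name and tensor shapes are assumptions):

```python
import torch


def pad_chunks(
    chunks: list[torch.Tensor], num_chunks: int
) -> tuple[list[torch.Tensor], list[bool]]:
    """Pad to num_chunks with a single real row from chunk 0, flagging dummies."""
    dummy_chunk_flags = [False] * len(chunks)
    while len(chunks) < num_chunks:
        # Reuse real data so the model still sees a valid token structure.
        chunks.append(chunks[0][:1])
        dummy_chunk_flags.append(True)
    return chunks, dummy_chunk_flags
```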
15 changes: 15 additions & 0 deletions ldp/nn/handlers/transformer_handler.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import atexit
import logging
import os
import socket
@@ -193,6 +194,14 @@ async def __call__( # type: ignore[override]
@staticmethod
def model_generate(model: PreTrainedModel, *args, **kwargs):
"""A method that can be used as module_call_fn to sample from an LLM."""
if dist.get_world_size() > 1:
synced_gpus = kwargs.pop("synced_gpus", None)
if synced_gpus is None:
logger.debug("synced_gpus not defined, defaulting to True.")
elif not synced_gpus:
raise ValueError("synced_gpus must be True when using FSDP.")
kwargs["synced_gpus"] = True

# Summoning params per https://github.com/pytorch/pytorch/issues/100069
# If model is not FSDP, this context manager is a no-op.
with FullyShardedDataParallel.summon_full_params(model, recurse=False):
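To illustrate the guard added above: under FSDP, every rank issues collectives during the forward pass, so `generate()` can deadlock if some ranks finish early; `synced_gpus=True` keeps all ranks stepping until the slowest one is done. A condensed sketch of the same idea (the standalone function is an assumption; the PR implements this inside `model_generate`):

```python
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel
from transformers import PreTrainedModel


def fsdp_generate(model: PreTrainedModel, *args, **kwargs):
    if dist.is_initialized() and dist.get_world_size() > 1:
        # All ranks must keep calling forward together, or FSDP's collectives hang.
        kwargs["synced_gpus"] = True
    # Gather full params for the top-level module; a no-op if model is not FSDP.
    with FullyShardedDataParallel.summon_full_params(model, recurse=False):
        return model.generate(*args, **kwargs)
```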
@@ -463,6 +472,8 @@ def __init__(self, config: TransformerHandlerConfig):

self._initialized = True

atexit.register(self.teardown)

# don't call AsyncTorchModule.__init__ because we don't need to set up module[_call_fn]
AsyncBufferedWorker.__init__(
self,
@@ -484,6 +495,10 @@ def _init_local_cluster(
# lazy import since dask-cuda only works on Linux machines
from dask_cuda import LocalCUDACluster

# This uses NVIDIA's NVML layer instead of native CUDA, which is more robust for GPU
# detection after initialization. It prevents forked processes from wrongly detecting
# the default GPU as cuda:0.
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"
self.cluster = LocalCUDACluster(
n_workers=parallel_mode_config.num_workers,
threads_per_worker=parallel_mode_config.num_cpus_per_worker,
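A standalone sketch of the environment-variable change above: `PYTORCH_NVML_BASED_CUDA_CHECK=1` makes PyTorch's CUDA availability check go through NVML instead of initializing a CUDA context, which matters when dask-cuda forks worker processes. The cluster sizes shown are placeholder values.

```python
import os

# Must be set before torch.cuda availability is first queried in the workers:
# the NVML path does not create a CUDA context, so forked processes do not
# inherit one and wrongly report cuda:0 as the default device.
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"

from dask_cuda import LocalCUDACluster  # lazy import: dask-cuda is Linux-only

cluster = LocalCUDACluster(n_workers=2, threads_per_worker=1)  # placeholder sizes
```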