Refactor Dask handling in transformer handler for improved exception management and memory efficiency
Ori Kabeli committed Mar 12, 2025
commit 3d7c20e7aef8ea320ad1b5eaef83aadd837a1ba2
2 changes: 1 addition & 1 deletion src/ldp/alg/rollout.py
@@ -8,7 +8,7 @@
from typing import Any, TypeVar

from aviary.core import Environment, Message
from tqdm import tqdm
from tqdm.asyncio import tqdm

from ldp.agent import Agent
from ldp.data_structures import Trajectory, Transition
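For context, the swap to tqdm.asyncio gives the same tqdm class plus async-aware helpers; a minimal sketch (hypothetical coroutine, not from rollout.py):

import asyncio

from tqdm.asyncio import tqdm


async def work(i: int) -> int:
    await asyncio.sleep(0.1)
    return i * i


async def main() -> None:
    # tqdm.gather mirrors asyncio.gather but renders a progress bar
    # that advances as each awaitable completes.
    results = await tqdm.gather(*(work(i) for i in range(10)))
    print(results)


asyncio.run(main())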
128 changes: 89 additions & 39 deletions src/ldp/nn/handlers/transformer_handler.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import atexit
import logging
import os
@@ -10,14 +11,15 @@
from enum import StrEnum, auto
from functools import cache, partial, wraps
from pathlib import Path
from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never, cast
from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never

import accelerate
import torch
import torch.distributed as dist
import tree
from dask import config
from dask.distributed import Client, as_completed
from dask.distributed import Actor, ActorFuture, Client
from distributed.utils import sync
from pydantic import BaseModel, ConfigDict, Field, field_validator
from torch import nn
from torch.cuda import nccl
@@ -199,6 +201,7 @@ def model_generate(model: PreTrainedModel, *args, **kwargs):
kwargs["synced_gpus"] = True
elif not synced_gpus:
raise ValueError("synced_gpus must be True when using FSDP.")
raise torch.OutOfMemoryError("yoyoyoyoyoyoyo test test test TODO") # TODO remove

# Summoning params per https://github.com/pytorch/pytorch/issues/100069
# If model is not FSDP, this context manager is a no-op.
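A minimal sketch of the generation pattern this hunk enforces (assumes an FSDP-wrapped Hugging Face model; fsdp_generate is an illustrative helper, not part of this module):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def fsdp_generate(model, **generate_kwargs):
    # Under FSDP, every rank must stay inside generate() together:
    # synced_gpus=True keeps ranks stepping in lockstep so a rank that
    # finishes decoding early does not deadlock the NCCL collectives.
    generate_kwargs.setdefault("synced_gpus", True)
    # Gather the sharded parameters so generate() sees the full weights;
    # per the comment above, this is a no-op for non-FSDP modules.
    with FSDP.summon_full_params(model):
        return model.generate(**generate_kwargs)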
@@ -425,22 +428,29 @@ def _exec_func(
args = tree.map_structure(to_device, args)
kwargs = tree.map_structure(to_device, kwargs)

with torch.autocast(
device_type=self.module.device.type, dtype=self.module.dtype
):
res = (
getattr(self, func)(*args, **kwargs)
if isinstance(func, str)
else func(self, *args, **kwargs)
)
try:
with torch.autocast(
device_type=self.module.device.type, dtype=self.module.dtype
):
res = (
getattr(self, func)(*args, **kwargs)
if isinstance(func, str)
else func(self, *args, **kwargs)
)

# Needed to prevent GPU memory leak to the main process scheduling the workers
if isinstance(res, GenerateDecoderOnlyOutput):
res.past_key_values = None
res["past_key_values"] = None
# Needed to prevent GPU memory leak to the main process scheduling the workers
if isinstance(res, GenerateDecoderOnlyOutput):
res.past_key_values = None
res["past_key_values"] = None

to_cpu = partial(_move_tensor, device=torch.device("cpu"))
return tree.map_structure(to_cpu, res)
to_cpu = partial(_move_tensor, device=torch.device("cpu"))
return tree.map_structure(to_cpu, res)
except Exception as e:
# Re-raise the exception with the traceback preserved. Dask modifies or
# loses the original traceback for some exception types when they cross
# process boundaries; wrapping in a plain RuntimeError and calling
# with_traceback() on it keeps the original traceback intact.
raise RuntimeError(str(e)).with_traceback(e.__traceback__) # noqa: B904

def __del__(self) -> None:
dist.destroy_process_group()
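A standalone sketch of the re-raise pattern above; CustomError is a hypothetical exception type whose non-standard constructor Dask could not rebuild across the process boundary:

import traceback


class CustomError(Exception):
    def __init__(self, message: str, code: int):  # non-standard signature
        super().__init__(message)
        self.code = code


def worker_side() -> None:
    try:
        raise CustomError("GPU worker failed", code=42)
    except Exception as e:
        # A plain RuntimeError pickles cleanly, and with_traceback() grafts
        # the original traceback onto it, so the caller still sees where
        # the failure happened inside the worker.
        raise RuntimeError(str(e)).with_traceback(e.__traceback__) from None


try:
    worker_side()
except RuntimeError:
    traceback.print_exc()  # traceback still points inside worker_side()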
@@ -582,7 +592,7 @@ def get_cuda_visible_devices() -> int | None:
futures.append(future_op)
worker_ids.append(worker_id)

self.handlers = self.client_gather(futures)
self.actors: list[Actor] = self._client_gather(futures)
self.worker_ids = worker_ids

async def __call__(
@@ -633,28 +643,24 @@ def _submit_and_gather(
"""
if split_data:
chunker = TensorChunker(
num_chunks=len(self.handlers),
num_chunks=len(self.actors),
)
split_args, split_kwargs, dummy_flags = chunker.chunkify(*args, **kwargs)
else:
split_args = [args] * len(self.handlers)
split_kwargs = [kwargs] * len(self.handlers)
split_args = [args] * len(self.actors)
split_kwargs = [kwargs] * len(self.actors)

futures = [
self.client.submit(
handler._exec_func,
handler._exec_func(
func,
*args_i,
workers=[worker_id],
actor=True,
**kwargs_i,
)
for handler, worker_id, args_i, kwargs_i in zip(
self.handlers, self.worker_ids, split_args, split_kwargs, strict=True
self.actors, self.worker_ids, split_args, split_kwargs, strict=True
)
]
results = self.client_gather(futures)
results = cast("list[TReturn]", [res.result().result() for res in results])
results: list[TReturn] = self._client_gather(futures)

if split_data:
return chunker.dechunkify(results, dummy_flags)
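For context, calling handler._exec_func(...) directly on an Actor handle follows Dask's actor API: the method executes on the worker hosting the actor and immediately returns an ActorFuture. A minimal sketch of that pattern:

from dask.distributed import Client


class Counter:
    """Stateful object pinned to one worker."""

    def __init__(self) -> None:
        self.n = 0

    def increment(self) -> int:
        self.n += 1
        return self.n


client = Client()
# actor=True asks the scheduler for an Actor handle instead of a plain future.
counter = client.submit(Counter, actor=True).result()
future = counter.increment()  # runs on the hosting worker, returns an ActorFuture
print(future.result())  # -> 1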
@@ -767,29 +773,73 @@ def teardown(self) -> None:
if self._initialized:
self.client.shutdown()
self.cluster.close()
del self.client
del self.cluster
self._initialized = False

def __del__(self) -> None:
self.teardown()

def client_gather(self, futures):
@staticmethod
def _wrap_dask_future(dask_future: ActorFuture):
"""Converts a Dask ActorFuture into an awaitable asyncio.Future."""
loop = asyncio.get_running_loop()
return asyncio.ensure_future(loop.run_in_executor(None, dask_future.result))

def _client_gather(self, futures: list[ActorFuture]) -> list[Any]:
"""Gather results from futures, propagating exceptions as they arrive.

Unlike client.gather() which waits for all futures to complete before raising
any exceptions, this method processes futures as they complete and raises
exceptions immediately. This is crucial when using FSDP where workers may
be stuck waiting for each other where one worker crashes, causing long hangs.
be stuck waiting for each other when one worker crashes, causing long hangs.

Note: Dask Actors currently do not work reliably with client.gather(): calls can
block, and worker errors can be hidden. This implementation works around those
limitations.
"""
# Initialize a list to hold results
results = [None] * len(futures)
for completed_future, result in as_completed(
futures, with_results=True, raise_errors=True
):
# Find the index of the completed future
index = futures.index(completed_future)
# Store the result directly from as_completed
results[index] = result
return results

async def _gather_with_exception_handling(futures):
wrapped_futures = [self._wrap_dask_future(f) for f in futures]

try:
# Use asyncio.wait with FIRST_EXCEPTION instead of gather
done, pending = await asyncio.wait(
wrapped_futures, timeout=120, return_when=asyncio.FIRST_EXCEPTION
)

exceptions = []
for future in done:
exc = future.exception()
if exc:
exceptions.append(exc)
if exceptions:
if len(exceptions) == 1:
raise exceptions[0]
raise ExceptionGroup("Multiple actor exceptions", exceptions)

if pending:
pending_indices = sorted([
wrapped_futures.index(p) for p in pending
])
raise TimeoutError(
f"Tasks didn't complete within timeout. {len(pending)} out of {len(wrapped_futures)} "
f"still pending. Pending task indices: {pending_indices}"
)

return await asyncio.gather(*wrapped_futures)
except Exception as e:
logger.exception("Error in dask workers")
for f in wrapped_futures:
if not f.done():
f.cancel()
self.teardown()
# sys.exit(1) would wait for dask to finish, which can cause hanging
# when workers are in a deadlock. Use os._exit to force immediate termination
os._exit(1)

# Use distributed.utils.sync to run the async function in the current thread
return sync(self.client.loop, _gather_with_exception_handling, futures) # type: ignore[arg-type]


# Helpers
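A condensed, standalone sketch of the fail-fast gather pattern in _client_gather above (illustrative names; the real method also tears down the cluster and calls os._exit on unrecoverable errors):

import asyncio


async def gather_fail_fast(tasks: list[asyncio.Task], timeout: float = 120):
    # Unlike Client.gather(), FIRST_EXCEPTION returns as soon as any task
    # fails, so one crashed worker cannot stall the whole batch.
    done, pending = await asyncio.wait(
        tasks, timeout=timeout, return_when=asyncio.FIRST_EXCEPTION
    )
    exceptions = [t.exception() for t in done if t.exception()]
    if exceptions:
        for p in pending:
            p.cancel()
        if len(exceptions) == 1:
            raise exceptions[0]
        raise ExceptionGroup("Multiple task exceptions", exceptions)
    if pending:
        raise TimeoutError(f"{len(pending)} of {len(tasks)} tasks still pending")
    # Index the original task list (not the unordered `done` set) to keep order.
    return [t.result() for t in tasks]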