trigger profiling on abort
Summary:
Record the profile trace if the training process receives SIGABRT, e.g. when the Process Group watchdog aborts the process.
tushar00jain committed Oct 29, 2025
commit 88bb526c17f9499a103ebc54b3049946854c7515
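
Why a C-level handler rather than signal.signal: CPython runs Python-level signal handlers only on the main thread, between bytecode instructions, so a SIGABRT raised from a non-Python thread (the Process Group watchdog calls abort() from C++) can kill the process before a Python handler ever runs. Registering through libc installs a raw C function pointer that executes synchronously in the aborting thread. A minimal sketch of the mechanism; the on_abort name and message are illustrative, not from this commit:

    import ctypes
    import os
    import signal

    libc = ctypes.CDLL(None)  # the process's own symbols; POSIX only

    # Keep a reference to the callback for as long as it is registered,
    # otherwise ctypes may free the underlying C thunk.
    @ctypes.CFUNCTYPE(None, ctypes.c_int)
    def on_abort(signum):
        # Runs at C level in whichever thread raised SIGABRT; a handler
        # installed via signal.signal() would wait on the main thread.
        os.write(2, b"SIGABRT caught\n")

    libc.signal(signal.SIGABRT, on_abort)

Calling back into Python from a C signal handler is not async-signal-safe in general; the commit accepts that risk because the process is about to die anyway and the trace would otherwise be lost.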
11 changes: 6 additions & 5 deletions torchtitan/experiments/forge/example_train.py
@@ -281,12 +281,13 @@ def train(self):
         self.checkpointer.load(step=job_config.checkpoint.load_step)
         logger.info(f"Training starts at step {self.step + 1}.")
 
+        torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
13 changes: 6 additions & 7 deletions torchtitan/tools/profiling.py
@@ -18,7 +18,6 @@
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
 
-@contextlib.contextmanager
 def maybe_enable_profiling(
     profiling_config: ProfilingConfig,
     *,
@@ -68,20 +67,20 @@ def trace_handler(prof):
             gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
         elif torch.xpu.is_available():
             gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
-        with torch.profiler.profile(
+        torch_profiler = torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 gpu_device_profiled,
             ],
             schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
             on_trace_ready=trace_handler,
             record_shapes=True,
-        ) as torch_profiler:
-            torch_profiler.step_num = global_step
-            yield torch_profiler
+        )
+        torch_profiler.step_num = global_step
+        torch_profiler.start()
+        return torch_profiler
     else:
-        torch_profiler = contextlib.nullcontext()
-        yield None
+        return None
 
 
 @contextlib.contextmanager
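
With the @contextlib.contextmanager decorator gone, maybe_enable_profiling is now a plain function: it builds the profiler, seeds step_num, calls start(), and returns the profiler object, or returns None when profiling is disabled. The caller therefore owns the profiler's lifetime; nothing stops it automatically on scope exit. A hedged caller-side sketch of the new contract (train_one_step and num_steps are placeholders, not names from this PR):

    profiler = maybe_enable_profiling(
        job_config.profiling,
        global_step=step,
        base_folder=job_config.job.dump_folder,
    )
    try:
        for _ in range(num_steps):
            train_one_step()     # placeholder for the real training step
            if profiler:
                profiler.step()  # advances the wait/warmup/active schedule
    finally:
        if profiler:
            profiler.stop()      # explicit, now that there is no with-block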
34 changes: 26 additions & 8 deletions torchtitan/train.py
@@ -4,8 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import ctypes
 import importlib
 import os
+import signal
 import time
 from datetime import timedelta
 from typing import Any, Generator, Iterable, Optional
@@ -32,8 +34,12 @@
     maybe_enable_profiling,
 )
 
+c_globals = ctypes.CDLL(None)  # POSIX
+
 
 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
+    torch_profiler: torch.profiler.profile | None = None
+
     # core configs
     job_config: JobConfig
     parallel_dims: ParallelDims
@@ -580,13 +586,14 @@ def train(self):
             if not self.ft_manager.enabled
             else f"replica_{self.ft_manager.replica_id}"
         )
+        self.torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+            leaf_folder=leaf_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-                leaf_folder=leaf_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
@@ -610,6 +617,15 @@
                 ),
             ),
         ):
+            if self.torch_profiler:
+
+                @ctypes.CFUNCTYPE(None, ctypes.c_int)
+                def sigabrt_handler(signal):
+                    logger.info("SIGABRT received. Stopping profiler")
+                    self.torch_profiler.export_chrome_trace("trace.json")
+
+                c_globals.signal(signal.SIGABRT, sigabrt_handler)
+
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():
                 self.step += 1
@@ -633,8 +649,8 @@
                     self.validator.validate(self.model_parts, self.step)
 
                 # signal the profiler that the next profiling step has started
-                if torch_profiler:
-                    torch_profiler.step()
+                if self.torch_profiler:
+                    self.torch_profiler.step()
                 if memory_profiler:
                     memory_profiler.step()
 
@@ -692,10 +708,12 @@ def close(self) -> None:
         else:
             trainer.train()
     except Exception:
+        logger.info("Torchtitan training threw an exception")
         if trainer:
             trainer.close()
         raise
     else:
+        logger.info("Torchtitan training completed")
         trainer.close()
         torch.distributed.destroy_process_group()
         logger.info("Process group destroyed")
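
In train(), the decorated sigabrt_handler closes over self, so the in-flight profiler can still dump a Chrome trace to trace.json when the watchdog aborts the process; the callback stays referenced for the duration of the training loop because it is a local of train(). A self-contained way to check that a libc-registered handler fires for an abort() raised off the main thread, which is the watchdog scenario this commit targets (a sketch assuming POSIX; the process still terminates afterwards, because abort() re-raises with the default disposition once the handler returns):

    import ctypes
    import os
    import signal
    import threading

    libc = ctypes.CDLL(None)

    @ctypes.CFUNCTYPE(None, ctypes.c_int)
    def handler(signum):
        os.write(2, b"handler ran before the process died\n")

    libc.signal(signal.SIGABRT, handler)
    threading.Thread(target=libc.abort).start()  # SIGABRT from a worker thread

That behavior also explains the handler's design: it only exports the trace and returns, rather than trying to keep the process alive.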