Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix setting ft state dicts when ft checkpointing is disabled
Summary:
- when ft dataloader checkpointing is disabled, we also don't set the ft state
- make it so that when ft checkpointing is disabled, we still set the state dict so that model, optimizer etc. can be recovered from a different replica
  • Loading branch information
tushar00jain committed Oct 29, 2025
commit 22a1a9a0ff0d4d86e91d3e8113c941b8a9c8bb7b
59 changes: 34 additions & 25 deletions torchtitan/components/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,15 +189,25 @@ def __init__(
self.enable = checkpoint_config.enable
self.load_only = checkpoint_config.load_only

self.states = states
self.states.update(
{
MODEL: ModelWrapper(model_parts),
OPTIMIZER: optimizers,
DATALOADER: dataloader,
LR_SCHEDULER: lr_schedulers,
}
)

self.ft_manager = (
ft_manager.manager
if ft_manager
and ft_manager.enabled
and checkpoint_config.enable_ft_dataloader_checkpoints
else None
ft_manager.manager if ft_manager and ft_manager.enabled else None
)

if ft_manager and ft_manager.enabled and not self.ft_manager:
self.enable_ft_dataloader_checkpoints = (
self.ft_manager and checkpoint_config.enable_ft_dataloader_checkpoints
)

if self.ft_manager and not self.enable_ft_dataloader_checkpoints:
logger.warn(
"Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
"This means replicas can retrain over the same data multiple times, which can result in overfitting."
Expand Down Expand Up @@ -229,20 +239,11 @@ def load_state_dict(state_dict):
async_mode = checkpoint_config.async_mode.lower()
self.enable_staging = (
self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
) or self.ft_manager
) or self.enable_ft_dataloader_checkpoints

if not self.enable and self.ft_manager is None:
if not self.enable and not self.enable_ft_dataloader_checkpoints:
return

self.states = states
self.states.update(
{
MODEL: ModelWrapper(model_parts),
OPTIMIZER: optimizers,
DATALOADER: dataloader,
LR_SCHEDULER: lr_schedulers,
}
)
self.ft_states = {DATALOADER: dataloader}

self.staging = False
Expand Down Expand Up @@ -279,7 +280,7 @@ def load_state_dict(state_dict):
if (
async_mode == AsyncMode.ASYNC
or async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
or self.ft_manager
or self.enable_ft_dataloader_checkpoints
):
self.pg = dist.new_group(backend="gloo")

Expand Down Expand Up @@ -480,14 +481,16 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
None
"""

if self.ft_manager:
if self.enable_ft_dataloader_checkpoints:
self._ft_save(curr_step)

if not self._should_save(curr_step, last_step):
return

begin = time.monotonic()
if not self.ft_manager or self.ft_manager.participating_rank() == 0:
if not self.enable_ft_dataloader_checkpoints or (
self.ft_manager and self.ft_manager.participating_rank() == 0
):
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
Expand Down Expand Up @@ -530,7 +533,8 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
elif self.ft_manager:
elif self.enable_ft_dataloader_checkpoints:
assert self.ft_manager is not None
logger.info(
"Replica %d doesn't save checkpoint.",
self.ft_manager.participating_rank(),
Expand All @@ -551,7 +555,7 @@ def load(self, step: int = -1) -> bool:
bool: Whether the checkpoint was loaded successfully.
"""

if self.ft_manager:
if self.enable_ft_dataloader_checkpoints:
self._ft_load()

if not self.enable:
Expand Down Expand Up @@ -749,7 +753,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:

states_to_load = self._flattened_model_states_sd(states_to_load)

if self.ft_manager:
if self.enable_ft_dataloader_checkpoints:
states_to_load.pop(DATALOADER)

return states_to_load
Expand Down Expand Up @@ -805,7 +809,9 @@ def _async_wait(self) -> None:
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
if self.save_future is not None:
self.save_future.result()
elif self.async_mode == AsyncMode.ASYNC or self.ft_manager is not None:
elif (
self.async_mode == AsyncMode.ASYNC or self.enable_ft_dataloader_checkpoints
):
if self.save_future is not None:
self.save_future.result()
self.save_future = None
Expand All @@ -820,7 +826,10 @@ def _purge_stale_checkpoints(self):
self.keep_latest_k > 0
and dist.get_rank() == 0
and os.path.isdir(self.folder)
and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
and (
not self.enable_ft_dataloader_checkpoints
or (self.ft_manager and self.ft_manager.participating_rank() == 0)
)
):
discovered_checkpoints = []
for filename in os.listdir(self.folder):
Expand Down
Loading