Make checkpoint manager extendable for other StorageWriters
dimdi-y committed Jan 10, 2025
commit 5f3afedeeb25e4dd83118e26110f9c142eaff2f5
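
For context, the point of this change is to let CheckpointManager forward a caller-supplied StorageWriter/StorageReader to dcp.save/dcp.load instead of always probing the local dump folder. Below is a minimal, self-contained sketch (not part of this commit, and assuming a recent PyTorch where dcp.save/dcp.load can run without an initialized process group) of what that plumbing forwards to, using the stock FileSystemWriter/FileSystemReader shipped with torch.distributed.checkpoint.

# Hedged sketch: dcp.save/dcp.load accept a StorageWriter/StorageReader
# alongside (or instead of) checkpoint_id. FileSystemWriter/FileSystemReader
# are the built-in implementations; a custom backend would implement the same
# StorageWriter/StorageReader interfaces and be passed the same way.
import tempfile

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter

state = {"weights": torch.arange(4.0)}
with tempfile.TemporaryDirectory() as ckpt_dir:
    # Equivalent to what the patched CheckpointManager does when storage_writer is set.
    dcp.save(state, storage_writer=FileSystemWriter(ckpt_dir))

    restored = {"weights": torch.zeros(4)}
    dcp.load(restored, storage_reader=FileSystemReader(ckpt_dir))
    assert torch.equal(restored["weights"], state["weights"])

With the new constructor arguments, the same writer/reader pair would instead be handed to CheckpointManager as storage_writer/storage_reader and threaded through the synchronous, async, and background-process save paths shown in the diff below.
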
113 changes: 89 additions & 24 deletions torchtitan/checkpoint.py
@@ -13,7 +13,7 @@
from dataclasses import dataclass, field
from io import BytesIO
from multiprocessing import get_context
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
@@ -25,6 +25,7 @@
StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.storage import StorageReader, StorageWriter
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
@@ -105,7 +106,7 @@ class SaveDone:
pass


def checkpoint_mp(recv, send):
def checkpoint_mp(recv, send, storage_writer):
init_logger()
os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
@@ -124,7 +125,11 @@ def checkpoint_mp(recv, send):
assert isinstance(obj, tuple)
begin = time.monotonic()
state, checkpoint_id = obj
dcp.save(state, checkpoint_id=checkpoint_id)
dcp.save(
state,
checkpoint_id=checkpoint_id,
storage_writer=storage_writer,
)
logger.info(
"Finish saving the checkpoint in the background process in "
f"{time.monotonic() - begin:.2f} seconds."
@@ -143,6 +148,8 @@ def __init__(
lr_schedulers: SchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
storage_reader: Optional[StorageReader] = None,
storage_writer: Optional[StorageWriter] = None,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
@@ -187,7 +194,9 @@ def __init__(
)
self.states.update(lr_schedulers.get_lr_scheduler_state())

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.folder = self._get_folder_from_job_config(job_config)
self.storage_reader = storage_reader
self.storage_writer = storage_writer
self.interval_type = (
IntervalType.SECONDS
if ckpt_config.interval_type == "seconds"
@@ -219,6 +228,7 @@ def __init__(
args=(
self.mp_queue_send,
self.mp_queue_recv,
self.storage_writer,
),
daemon=True,
)
@@ -234,6 +244,16 @@ def __init__(
f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
)

def _get_folder_from_job_config(self, job_config: JobConfig) -> str:
"""Construct self.folder from job config.

Can be overridden for compatibility with custom storage_reader/storage_writer
"""
return os.path.join(
job_config.job.dump_folder,
job_config.checkpoint.folder,
)

def __del__(self):
if self.enable_checkpoint and self.mp and self.mp.is_alive():
self.mp_queue_send.put(Terminate())
@@ -243,6 +263,11 @@ def reset(self) -> None:
self.begin_time = time.monotonic()

def _create_checkpoint_id(self, step: int) -> str:
"""Convert step to checkpoint id acceptable by dcp.save

Can be overridden for compatibility with custom storage_reader/storage_writer
"""

return os.path.join(self.folder, f"step-{step}")

def _save_last_step(self, curr_step: int) -> None:
@@ -274,7 +299,11 @@ def _save_last_step(self, curr_step: int) -> None:
else:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
dcp.save(
self.states,
checkpoint_id=self._create_checkpoint_id(curr_step),
storage_writer=self.storage_writer,
)
self.reset()

def _should_save(self, curr_step: int, force: bool = False) -> bool:
@@ -369,10 +398,17 @@ def save(self, curr_step: int, force: bool = False) -> None:
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
self.states,
checkpoint_id=checkpoint_id,
process_group=self.pg,
storage_writer=self.storage_writer,
)
else:
dcp.save(self.states, checkpoint_id=checkpoint_id)
dcp.save(
self.states,
checkpoint_id=checkpoint_id,
storage_writer=self.storage_writer,
)
self.reset()
self._purge_stale_checkpoints()

@@ -402,24 +438,53 @@ def sync_func():
sync_func()
self.staging = False

def _check_checkpoint_exitsts(self, step: int) -> bool:
@lessw2020 (Contributor) commented on Jan 13, 2025:

Thanks for this PR!
I wanted to note that this should be "exists" rather than the current spelling, "exitsts"...thus _check_checkpoint_exists.
In a way it's a nit, but it's also important, as misspelled APIs can create confusion (I have personally hit cases where you type the correct spelling but the API has the misspelling, and non-obvious failures result).
There are 2 calls to this function that would also need to be updated.

Contributor:

Minor point - '_verify_checkpoint_exists' would be a smoother name (vs. check_checkpoint).

Author:

Nice catch, thanks! Will surely fix this.

"""Check if a checkpoint has been fully written for the corresponding step

Can be overridden for compatibility with custom storage_reader/storage_writer
"""

checkpoint_id = self._create_checkpoint_id(step)
metadata_probe = os.path.join(checkpoint_id, ".metadata")
return os.path.isfile(metadata_probe)

def _discover_checkpointed_steps(self) -> List[int]:
"""List steps that have their corresponding directories created

Can be overridden for compatibility with custom storage_reader/storage_writer
"""
if not os.path.isdir(self.folder):
return []

discovered_steps = []
for filename in os.listdir(self.folder):
match = re.search(r"step-(\d+)", filename)
if not match:
continue
step = int(match.group(1))
discovered_steps.append(step)
if not discovered_steps:
return None
return discovered_steps

def _find_last_saved_step(self) -> Optional[int]:
all_steps = self._discover_checkpointed_steps()
fully_written_steps = list(filter(self._check_checkpoint_exitsts, all_steps))

return max(fully_written_steps, default=None)

def load(self, step: int = -1) -> bool:
if not self.enable_checkpoint:
return False
if not os.path.isdir(self.folder):
return False
if step != -1 and not os.path.isdir(self._create_checkpoint_id(step)):
return False

if step == -1:
step_counts = []
for filename in os.listdir(self.folder):
match = re.search(r"step-(\d+)", filename)
metadata_probe = os.path.join(self.folder, filename, ".metadata")
if match and os.path.isfile(metadata_probe):
step_counts.append(int(match.group(1)))
if not step_counts:
last_step = self._find_last_saved_step()
if last_step is None:
return False
step = max(step_counts)
step = last_step

if not self._check_checkpoint_exitsts(step):
return False

# We won't have optimizer states to load, if we are loading a seed checkpoint
states = {"model": self.states["model"]} if step == 0 else self.states
@@ -437,6 +502,7 @@ def load(self, step: int = -1) -> bool:
dcp.load(
states,
checkpoint_id=self._create_checkpoint_id(step),
storage_reader=self.storage_reader,
)
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -448,11 +514,10 @@ def load(self, step: int = -1) -> bool:

def _purge_stale_checkpoints(self):
if self.keep_latest_k > 0:
discovered_checkpoints = []
for filename in os.listdir(self.folder):
match = re.search(r"step-(\d+)", filename)
path = os.path.join(self.folder, filename)
discovered_checkpoints.append((int(match.group(1)), path))
discovered_checkpoints = [
(step, self._create_checkpoint_id(step))
for step in self._discover_checkpointed_steps()
]

discovered_checkpoints.sort()
to_delete = discovered_checkpoints[: -1 * self.keep_latest_k]
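
To show the kind of extension these hooks enable, here is a hedged sketch of a CheckpointManager subclass that maps checkpoint ids onto object-store keys instead of local paths. Everything here is hypothetical illustration: _FakeObjectStore stands in for a real storage client, ObjectStoreCheckpointManager is not part of torchtitan, and only the overridden method names come from this commit (including the _check_checkpoint_exitsts spelling, which the review thread above flags for renaming).

from typing import List

from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig


class _FakeObjectStore:
    """Stand-in for a real object-store client; illustrative only."""

    def __init__(self) -> None:
        self._keys = set()

    def exists(self, key: str) -> bool:
        return key in self._keys

    def list_keys(self, prefix: str) -> List[str]:
        return [k for k in self._keys if k.startswith(prefix)]


_store = _FakeObjectStore()


class ObjectStoreCheckpointManager(CheckpointManager):
    """Overrides the path-probing hooks so checkpoint ids become object keys.

    The matching custom StorageWriter/StorageReader would be passed through
    the new storage_writer/storage_reader constructor arguments.
    """

    def _get_folder_from_job_config(self, job_config: JobConfig) -> str:
        # A key prefix instead of a directory under job.dump_folder.
        return f"my-bucket/{job_config.checkpoint.folder}"

    def _create_checkpoint_id(self, step: int) -> str:
        # dcp.save/dcp.load receive this id; the custom writer/reader must
        # understand it as an object-store prefix.
        return f"{self.folder}/step-{step}"

    def _check_checkpoint_exitsts(self, step: int) -> bool:  # spelling as in this commit
        # Probe the backend for the .metadata object instead of os.path.isfile.
        return _store.exists(f"{self._create_checkpoint_id(step)}/.metadata")

    def _discover_checkpointed_steps(self) -> List[int]:
        steps = set()
        prefix = f"{self.folder}/step-"
        for key in _store.list_keys(prefix):
            steps.add(int(key[len(prefix):].split("/")[0]))
        return sorted(steps)

Because save, load, and _purge_stale_checkpoints only go through these hooks and through the injected storage objects, a subclass like this never touches os.listdir or os.path.isfile, which is what "extendable for other StorageWriters" buys.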