From 7930ea9612f6ff0cfec5d61c181bc82f85b06fd0 Mon Sep 17 00:00:00 2001
From: Lei
Date: Sat, 10 May 2025 12:16:12 -0700
Subject: [PATCH 1/2] Update export_dcp.py

refactor: simplify model initialization and ensure at least one shard for model parameters
---
 scripts/export_dcp.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/scripts/export_dcp.py b/scripts/export_dcp.py
index dd21e3d5..0e953582 100644
--- a/scripts/export_dcp.py
+++ b/scripts/export_dcp.py
@@ -135,11 +135,8 @@ def main(config: ExportConfig):
 
     logger.info("Getting model")
     model, model_config = get_model(
-        config.name_model,
-        config.type_model,
-        vocab_size=len(tokenizer),
-        seq_length=config.data.seq_length,
-        attn_fn=config.train.attn_fn,
+        config,
+        len(tokenizer)
     )
 
     # Convert ZeroBand config to HuggingFace config
@@ -162,7 +159,7 @@ def main(config: ExportConfig):
     logger.info("After load: %s", get_module_signature(model))
 
     # Convert model to HuggingFace format
-    num_shards = int(sum(p.numel() for p in model.parameters()) / 1e9)
+    num_shards = max(1, int(sum(p.numel() for p in model.parameters()) / 1e9))
 
     state_dict = model.state_dict()
     index_json = {}

From 10c060b2014e1dcff258b8c26600f84572b4f694 Mon Sep 17 00:00:00 2001
From: Lei
Date: Sat, 10 May 2025 12:18:46 -0700
Subject: [PATCH 2/2] Update checkpoint.py

feat: use cloudpickle for saving and loading state in CkptManager
---
 src/zeroband/checkpoint.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py
index bdeb4d48..34f6fc46 100644
--- a/src/zeroband/checkpoint.py
+++ b/src/zeroband/checkpoint.py
@@ -1,3 +1,4 @@
+import cloudpickle
 from dataclasses import dataclass
 import gc
 import multiprocessing
@@ -297,7 +298,7 @@ def _save(self, ckpt_path: str):
                 state = {}
                 state["optimizer"] = OuterOptimizerWrapper(self.diloco_offloaded_optimizer).state_dict()
-                torch.save(state, f)
+                torch.save(state, f, pickle_module=cloudpickle)
 
         data_path = os.path.join(ckpt_path, "data")
         self.save_data(data_path, self.dataloader, self.world_info.local_rank)
 
@@ -320,7 +321,7 @@ def save_data(data_path: str, dataloader, local_rank: int):
         os.makedirs(data_path, exist_ok=True)
         with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
             state = {"data_loader": dataloader.state_dict()}
-            torch.save(state, f)
+            torch.save(state, f, pickle_module=cloudpickle)
 
     def _async_save_remote(self, ckpt_path: str, remote_ckpt_path: str, blocking: bool = True) -> None:
         """asyncronously rsync a ckpt folder to a remote location. Using fsspec to handle remote cloud storage without to install
@@ -354,7 +355,7 @@ def wait_for_blocking_job(self):
 
         if self.world_info.local_rank == 0:
             if self.config.topk is not None:
-                delete_topk(self.logger, self.config.path, self.config.topk)
+                delete_topk(self._logger, self.config.path, self.config.topk)
 
     def _del__(self):
         self.wait_for_blocking_job()
@@ -370,7 +371,7 @@ def _load_data(self, resume_ckpt_path: str):
         data_path = os.path.join(resume_ckpt_path, "data")
 
         with open(os.path.join(data_path, f"_{world_info.local_rank}.pt"), "rb") as f:
-            state = torch.load(f)
+            state = torch.load(f, pickle_module=cloudpickle)
         self.dataloader.load_state_dict(state["data_loader"])
 
     @torch.no_grad()
@@ -415,7 +416,7 @@ def load(
 
         if self.diloco_offloaded_optimizer:
            with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
-                rank_state_dict = torch.load(f)
+                rank_state_dict = torch.load(f, pickle_module=cloudpickle)
             opt_wrapper = OuterOptimizerWrapper(self.diloco_offloaded_optimizer)
             opt_wrapper.load_state_dict(rank_state_dict["optimizer"])
 
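
Note on PATCH 1/2: without the max(1, ...) guard, any model under 1e9
parameters floors to num_shards == 0, which would presumably leave the
subsequent sharded export with nothing to write. A minimal standalone
sketch of the before/after behavior (the parameter count here is
hypothetical, not taken from the patch):

    # Old vs. patched shard computation for a sub-1B-parameter model.
    param_count = 124_000_000  # hypothetical ~124M-parameter model

    num_shards_old = int(param_count / 1e9)          # floors to 0
    num_shards_new = max(1, int(param_count / 1e9))  # clamped to at least 1

    assert num_shards_old == 0
    assert num_shards_new == 1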
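Note on PATCH 2/2: passing pickle_module=cloudpickle to torch.save and
torch.load matters when checkpointed state contains objects the standard
library pickle cannot serialize by reference, such as lambdas or locally
defined functions. A minimal standalone sketch of the difference (the
lambda here is illustrative, not taken from the patch):

    import pickle
    import cloudpickle

    fn = lambda x: x + 1  # stands in for a callable captured in checkpoint state

    try:
        pickle.dumps(fn)  # stdlib pickle serializes functions by reference only
    except (pickle.PicklingError, AttributeError) as e:
        print(f"stdlib pickle fails: {e}")

    payload = cloudpickle.dumps(fn)   # cloudpickle serializes the function by value
    restored = pickle.loads(payload)  # the resulting bytes load with plain pickle
    assert restored(1) == 2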