Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions scripts/export_dcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,8 @@ def main(config: ExportConfig):

logger.info("Getting model")
model, model_config = get_model(
config.name_model,
config.type_model,
vocab_size=len(tokenizer),
seq_length=config.data.seq_length,
attn_fn=config.train.attn_fn,
config,
len(tokenizer)
)

# Convert ZeroBand config to HuggingFace config
Expand All @@ -162,7 +159,7 @@ def main(config: ExportConfig):
logger.info("After load: %s", get_module_signature(model))

# Convert model to HuggingFace format
num_shards = int(sum(p.numel() for p in model.parameters()) / 1e9)
num_shards = max(1, int(sum(p.numel() for p in model.parameters()) / 1e9))
state_dict = model.state_dict()

index_json = {}
Expand Down
11 changes: 6 additions & 5 deletions src/zeroband/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import cloudpickle
from dataclasses import dataclass
import gc
import multiprocessing
Expand Down Expand Up @@ -297,7 +298,7 @@ def _save(self, ckpt_path: str):
state = {}
state["optimizer"] = OuterOptimizerWrapper(self.diloco_offloaded_optimizer).state_dict()

torch.save(state, f)
torch.save(state, f, pickle_module=cloudpickle)

data_path = os.path.join(ckpt_path, "data")
self.save_data(data_path, self.dataloader, self.world_info.local_rank)
Expand All @@ -320,7 +321,7 @@ def save_data(data_path: str, dataloader, local_rank: int):
os.makedirs(data_path, exist_ok=True)
with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
state = {"data_loader": dataloader.state_dict()}
torch.save(state, f)
torch.save(state, f, pickle_module=cloudpickle)

def _async_save_remote(self, ckpt_path: str, remote_ckpt_path: str, blocking: bool = True) -> None:
"""asyncronously rsync a ckpt folder to a remote location. Using fsspec to handle remote cloud storage without to install
Expand Down Expand Up @@ -354,7 +355,7 @@ def wait_for_blocking_job(self):

if self.world_info.local_rank == 0:
if self.config.topk is not None:
delete_topk(self.logger, self.config.path, self.config.topk)
delete_topk(self._logger, self.config.path, self.config.topk)

def _del__(self):
self.wait_for_blocking_job()
Expand All @@ -370,7 +371,7 @@ def _load_data(self, resume_ckpt_path: str):
data_path = os.path.join(resume_ckpt_path, "data")

with open(os.path.join(data_path, f"_{world_info.local_rank}.pt"), "rb") as f:
state = torch.load(f)
state = torch.load(f, pickle_module=cloudpickle)
self.dataloader.load_state_dict(state["data_loader"])

@torch.no_grad()
Expand Down Expand Up @@ -415,7 +416,7 @@ def load(

if self.diloco_offloaded_optimizer:
with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
rank_state_dict = torch.load(f)
rank_state_dict = torch.load(f, pickle_module=cloudpickle)

opt_wrapper = OuterOptimizerWrapper(self.diloco_offloaded_optimizer)
opt_wrapper.load_state_dict(rank_state_dict["optimizer"])
Expand Down