diff --git a/docs/examples/config.rst b/docs/examples/config.rst
index a3f910b3645..076ab24df18 100644
--- a/docs/examples/config.rst
+++ b/docs/examples/config.rst
@@ -137,7 +137,11 @@ Actor/Rollout/Reference Policy
 optimizer_offload: False
 fsdp_size: -1
 checkpoint:
- contents: ['model', 'optimizer', 'extra']
+ # What to include in saved checkpoints
+ # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
+ save_contents: ['model', 'optimizer', 'extra']
+ # For more flexibility, you can specify the contents to load from the checkpoint.
+ load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}
 ref:
 fsdp_config:
 param_offload: False
@@ -267,9 +271,11 @@ Actor/Rollout/Reference Policy
 - ``actor_rollout_ref.actor.checkpoint``: The configurations of checkpoint function in actor
- - ``contents``: The contents to save in the checkpoint. By default, we save model, optimizer and extra information in the checkpoint.
+ - ``save_contents``: The contents to save in the checkpoint. By default, we save the model, optimizer and extra information in the checkpoint. The extra information currently includes the RNG states; the lr_scheduler state is supported with FSDP, and Megatron opt_param_scheduler support is coming soon.
- We do not store hf_model in checkpoint by default, but we provide a tool in `scripts/model_merge.py` to convert checkpoint format to hf format.
+ We do not store hf_model in the checkpoint by default, but we provide a tool in ``scripts/model_merge.py`` to convert the checkpoint format to hf format.
+
+ - ``load_contents``: The contents to load from the checkpoint; you can specify loading contents that differ from what was saved. By default, it is the same as ``save_contents``.
 **Reference Model**
diff --git a/recipe/spin/fsdp_workers.py b/recipe/spin/fsdp_workers.py
index 13d68ccd196..36fa9ec7ec0 100644
--- a/recipe/spin/fsdp_workers.py
+++ b/recipe/spin/fsdp_workers.py
@@ -41,21 +41,20 @@ from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
 logger = logging.getLogger(__file__)
-logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
+logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))
 def create_device_mesh(world_size, fsdp_size):
 if fsdp_size < 0 or fsdp_size >= world_size:
- device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])
+ device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
 else:
- device_mesh = init_device_mesh('cuda',
- mesh_shape=(world_size // fsdp_size, fsdp_size),
- mesh_dim_names=['ddp', 'fsdp'])
+ device_mesh = init_device_mesh("cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"])
 return device_mesh
 def get_sharding_strategy(device_mesh):
 from torch.distributed.fsdp import ShardingStrategy
+
 if device_mesh.ndim == 1:
 sharding_strategy = ShardingStrategy.FULL_SHARD
 elif device_mesh.ndim == 2:
@@ -64,18 +63,20 @@ def get_sharding_strategy(device_mesh):
 raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
 return sharding_strategy
+
 class SPINRolloutRefWorker(ActorRolloutRefWorker):
-
 @register(dispatch_mode=Dispatch.ONE_TO_ALL)
 def init_model(self):
 from recipe.spin.dp_actor import SPINDataParallelPPOActor as DataParallelPPOActor
+
 # This is used to import external_lib into the huggingface systems
- import_external_libs(self.config.model.get('external_lib', None))
+ import_external_libs(self.config.model.get("external_lib", None))
 from omegaconf import OmegaConf
- 
override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) - use_remove_padding = self.config.model.get('use_remove_padding', False) + override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + + use_remove_padding = self.config.model.get("use_remove_padding", False) use_fused_kernels = self.config.model.get('use_fused_kernels', False) if self._is_actor or self._is_rollout or self._is_ref: @@ -93,17 +94,18 @@ def init_model(self): override_model_config=override_model_config, use_remove_padding=use_remove_padding, use_fused_kernels=use_fused_kernels, - enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False), - trust_remote_code=self.config.model.get('trust_remote_code', False), - use_liger=self.config.model.get('use_liger', False), - role='actor') + enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", + ) # get the original unwrapped module self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage('After offload actor optimizer during init', logger=logger) + log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) # load from checkpoint if self._is_actor or self._is_ref: OmegaConf.set_struct(self.config.actor, True) @@ -115,8 +117,7 @@ def init_model(self): actor_optimizer=self.actor_optimizer) if self._is_rollout: - self.rollout, self.rollout_sharding_manager = self._build_rollout( - trust_remote_code=self.config.model.get('trust_remote_code', False)) + self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) if self._is_ref: self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path, @@ -134,23 +135,12 @@ def init_model(self): self.config.ref.use_remove_padding = use_remove_padding self.config.ref.use_fused_kernels = use_fused_kernels self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) - self.checkpoint_manager = FSDPCheckpointManager( - model=self.actor_module_fsdp, - optimizer=self.actor.actor_optimizer, - lr_scheduler=self.actor_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.actor.checkpoint.contents) - + self.checkpoint_manager = FSDPCheckpointManager(model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_contents=self.config.actor.checkpoint) if self._is_actor: self.flops_counter = FlopsCounter(self.actor_model_config) - self.checkpoint_manager = FSDPCheckpointManager( - model=self.actor_module_fsdp, - optimizer=self.actor.actor_optimizer, - lr_scheduler=self.actor_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.actor.checkpoint.contents) - + self.checkpoint_manager = FSDPCheckpointManager(model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, 
checkpoint_contents=self.config.actor.checkpoint) + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_ref_log_prob(self, data: DataProto): assert self._is_ref @@ -159,17 +149,17 @@ def compute_ref_log_prob(self, data: DataProto): data = data.to(torch.cuda.current_device()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu - data.meta_info['micro_batch_size'] = micro_batch_size - data.meta_info['temperature'] = self.config.rollout.temperature - data.meta_info['max_token_len'] = self.config.ref.log_prob_max_token_len_per_gpu - data.meta_info['use_dynamic_bsz'] = self.config.ref.log_prob_use_dynamic_bsz + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["temperature"] = self.config.rollout.temperature + data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) output = self.ref_policy.compute_log_prob(data=data) - output = DataProto.from_dict(tensors={'ref_log_prob': output}) + output = DataProto.from_dict(tensors={"ref_log_prob": output}) output = self.ulysses_sharding_manager.postprocess_data(output) - output = output.to('cpu') + output = output.to("cpu") # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module @@ -187,19 +177,18 @@ def compute_log_prob(self, data: DataProto): # Support all hardwares data = data.to(torch.cuda.current_device()) # we should always recompute old_log_probs when it is HybridEngine - data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu - data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu - data.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz - data.meta_info['temperature'] = self.config.rollout.temperature + data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu + data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz + data.meta_info["temperature"] = self.config.rollout.temperature # perform recompute log_prob with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) output = self.actor.compute_log_prob(data=data) - output = DataProto.from_dict(tensors={'old_log_probs': output}, - meta_info={'temperature': self.config.rollout.temperature}) + output = DataProto.from_dict(tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}) output = self.ulysses_sharding_manager.postprocess_data(output) - output = output.to('cpu') + output = output.to("cpu") # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module @@ -209,7 +198,7 @@ def compute_log_prob(self, data: DataProto): if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage('After compute_log_prob', logger=logger) + log_gpu_memory_usage("After compute_log_prob", logger=logger) return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @@ -222,9 +211,9 @@ def update_actor_dpo(self, data: DataProto): # Support all hardwares data = data.to(torch.cuda.current_device()) - assert self._is_actor # Make sure this worker has the actor role + assert self._is_actor # Make sure this worker has the actor role if self.actor is None: - raise 
RuntimeError("Actor instance (self.actor) not initialized in worker.") + raise RuntimeError("Actor instance (self.actor) not initialized in worker.") # --- FSDP State Management --- if self._is_offload_param: @@ -232,36 +221,39 @@ def update_actor_dpo(self, data: DataProto): if self._is_offload_optimizer: load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()) - log_gpu_memory_usage('Before update policy (DPO via PPO path)', logger=logger) + log_gpu_memory_usage("Before update policy (DPO via PPO path)", logger=logger) # --- Ulysses Sharding (if used) --- with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) # --- Call the core update method (now containing DPO logic) --- - with Timer(name='update_policy_dpo_via_ppo', logger=None) as timer: # Use a distinct timer name + with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer: # Use a distinct timer name # Calls the modified update_policy method - metrics = self.actor.update_policy_dpo_with_ref(data=data) # <-- THIS CALLS THE MODIFIED FUNCTION + metrics = self.actor.update_policy_dpo_with_ref(data=data) # <-- THIS CALLS THE MODIFIED FUNCTION delta_time = timer.last # --- Add Performance Metrics --- # MFU calculation might be less accurate/meaningful here for DPO - metrics['perf/approx_tokens_processed'] = torch.sum(data.batch.get('attention_mask', torch.tensor(0))).item() # Approx tokens - metrics['perf/max_memory_allocated_gb'] = torch.cuda.max_memory_allocated() / (1024**3) - metrics['perf/max_memory_reserved_gb'] = torch.cuda.max_memory_reserved() / (1024**3) - metrics['perf/cpu_memory_used_gb'] = psutil.virtual_memory().used / (1024**3) + metrics["perf/approx_tokens_processed"] = torch.sum(data.batch.get("attention_mask", torch.tensor(0))).item() # Approx tokens + metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) + global_num_tokens = data.meta_info["global_token_num"] + estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + metrics["perf/mfu/actor"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size # --- LR Scheduler Step --- self.actor_lr_scheduler.step() lr = self.actor_lr_scheduler.get_last_lr()[0] - metrics['actor/lr'] = lr + metrics["actor/lr"] = lr - log_gpu_memory_usage('After update policy (DPO via PPO path)', logger=logger) + log_gpu_memory_usage("After update policy (DPO via PPO path)", logger=logger) # --- Prepare Output --- - output = DataProto(meta_info={'metrics': metrics}) + output = DataProto(meta_info={"metrics": metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) - output = output.to('cpu') + output = output.to("cpu") # --- FSDP State Management (Offload) --- if self._is_offload_param: @@ -270,7 +262,6 @@ def update_actor_dpo(self, data: DataProto): offload_fsdp_optimizer(optimizer=self.actor_optimizer) return output - # TODO(sgm): we may need to extract it to dp_reward_model.py @@ -282,6 +273,7 @@ class RewardModelWorker(Worker): def __init__(self, config): super().__init__() import torch.distributed + if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") self.config = config @@ -294,16 +286,14 @@ def __init__(self, config): self.device_mesh = create_device_mesh(world_size=world_size, 
fsdp_size=fsdp_size) self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', - mesh_shape=(dp, self.ulysses_sequence_parallel_size), - mesh_dim_names=['dp', 'sp']) + self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - self.use_remove_padding = self.config.model.get('use_remove_padding', False) + self.use_remove_padding = self.config.model.get("use_remove_padding", False) # normalize config if self.config.micro_batch_size is not None: @@ -324,29 +314,24 @@ def _build_model(self, config): else: self._do_switch_chat_template = True input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer) - self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, - trust_remote_code=config.model.get('trust_remote_code', False)) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get('trust_remote_code', False)) + self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False)) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) - trust_remote_code = config.model.get('trust_remote_code', False) + trust_remote_code = config.model.get("trust_remote_code", False) model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) model_config.num_labels = 1 # note that we have to create model in fp32. 
Otherwise, the optimizer is in bf16, which is incorrect - init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings, - mesh=self.device_mesh) + init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") model_config.classifier_dropout = 0.0 - reward_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path, - config=model_config, - torch_dtype=torch.bfloat16, - attn_implementation='flash_attention_2', - trust_remote_code=trust_remote_code) + reward_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path, config=model_config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", trust_remote_code=trust_remote_code) - if config.model.get('use_remove_padding', False) or self.ulysses_sequence_parallel_size > 1: + if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1: from verl.models.transformers.monkey_patch import apply_monkey_patch + apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size) reward_module.to(torch.bfloat16) @@ -366,14 +351,15 @@ def _build_model(self, config): sync_module_states=True, cpu_offload=CPUOffload(offload_params=True), forward_prefetch=False, - device_mesh=self.device_mesh) + device_mesh=self.device_mesh, + ) return reward_module @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get('external_lib', None)) + import_external_libs(self.config.model.get("external_lib", None)) self.reward_module = self._build_model(config=self.config) def _forward_micro_batch(self, micro_batch): @@ -381,49 +367,36 @@ def _forward_micro_batch(self, micro_batch): from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs - with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): - input_ids = micro_batch['input_ids'] + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape - attention_mask = micro_batch['attention_mask'] - position_ids = micro_batch['position_ids'] + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ - position_ids_rmpad, \ - sp_size=self.ulysses_sequence_parallel_size) + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size) # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.reward_module(input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - use_cache=False) # prevent model thinks we are generating + output = self.reward_module(input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False) # prevent model thinks we are generating reward_rmpad = output.logits reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: - reward_rmpad = gather_outpus_and_unpad(reward_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size) + reward_rmpad = gather_outpus_and_unpad(reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size) # pad it back rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) else: - output = self.reward_module(input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=False) + output = self.reward_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False) rm_score = output.logits # (batch_size, seq_len, 1) rm_score = rm_score.squeeze(-1) @@ -435,9 +408,9 @@ def _forward_micro_batch(self, micro_batch): def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): batch_size = data.batch.batch_size[0] # expand as token_level_reward - attention_mask = data.batch['attention_mask'] - position_ids = data.batch['position_ids'] - response_length = data.batch['responses'].shape[-1] + attention_mask = data.batch["attention_mask"] + position_ids = data.batch["position_ids"] + response_length = data.batch["responses"].shape[-1] eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores @@ -448,7 +421,7 @@ def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): return token_level_scores def _switch_chat_template(self, data: DataProto): - src_max_length = data.batch['attention_mask'].shape[-1] + src_max_length = data.batch["attention_mask"].shape[-1] src_tokenizer = self.input_tokenizer target_tokenizer = self.tokenizer @@ -458,44 +431,43 @@ def _switch_chat_template(self, data: DataProto): for i in range(data.batch.batch_size[0]): # extract raw prompt - if isinstance(data.non_tensor_batch['raw_prompt'][i], list): - chat: list = data.non_tensor_batch['raw_prompt'][i] + if isinstance(data.non_tensor_batch["raw_prompt"][i], list): + chat: list = data.non_tensor_batch["raw_prompt"][i] else: - chat: list = data.non_tensor_batch['raw_prompt'][i].tolist() + chat: list = data.non_tensor_batch["raw_prompt"][i].tolist() # extract response - response_ids = data.batch['responses'][i] + response_ids = data.batch["responses"][i] response_length = response_ids.shape[-1] - valid_response_length = data.batch['attention_mask'][i][-response_length:].sum() + valid_response_length = 
data.batch["attention_mask"][i][-response_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode response = src_tokenizer.decode(valid_response_ids) # remove bos and eos - response = response.replace(src_tokenizer.eos_token, '') + response = response.replace(src_tokenizer.eos_token, "") - chat.append({'role': 'assistant', 'content': response}) + chat.append({"role": "assistant", "content": response}) - prompt_with_chat_template = target_tokenizer.apply_chat_template(chat, - add_generation_prompt=False, - tokenize=False) + prompt_with_chat_template = target_tokenizer.apply_chat_template(chat, add_generation_prompt=False, tokenize=False) if self.rank == 0 and i == 0: # for debugging purpose - print(f'Switch template. chat: {prompt_with_chat_template}') + print(f"Switch template. chat: {prompt_with_chat_template}") # the maximum length is actually determined by the reward model itself - max_length = self.config.get('max_length', src_max_length) + max_length = self.config.get("max_length", src_max_length) if max_length is None: max_length = src_max_length - model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors='pt', add_special_tokens=False) + model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) input_ids, attention_mask = verl_F.postprocess_data( - input_ids=model_inputs['input_ids'], - attention_mask=model_inputs['attention_mask'], + input_ids=model_inputs["input_ids"], + attention_mask=model_inputs["attention_mask"], max_length=max_length, pad_token_id=target_tokenizer.pad_token_id, left_pad=False, # right padding - truncation=self.config.get('truncation', 'right')) # truncate from the right + truncation=self.config.get("truncation", "right"), + ) # truncate from the right rm_input_ids.append(input_ids) rm_attention_mask.append(attention_mask) @@ -505,7 +477,7 @@ def _switch_chat_template(self, data: DataProto): rm_position_ids = compute_position_id_with_mask(rm_attention_mask) - rm_inputs = {'input_ids': rm_input_ids, 'attention_mask': rm_attention_mask, 'position_ids': rm_position_ids} + rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} return DataProto.from_dict(rm_inputs) @@ -514,19 +486,16 @@ def compute_rm_score(self, data: DataProto): import itertools from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches + # Support all hardwares data = data.to(torch.cuda.current_device()) if self._do_switch_chat_template: rm_data = self._switch_chat_template(data) else: - rm_input_ids = data.batch['input_ids'] - rm_attention_mask = data.batch['attention_mask'] - rm_position_ids = data.batch['position_ids'] - rm_inputs = { - 'input_ids': rm_input_ids, - 'attention_mask': rm_attention_mask, - 'position_ids': rm_position_ids - } + rm_input_ids = data.batch["input_ids"] + rm_attention_mask = data.batch["attention_mask"] + rm_position_ids = data.batch["position_ids"] + rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} rm_data = DataProto.from_dict(rm_inputs) # Support all hardwares @@ -557,12 +526,12 @@ def compute_rm_score(self, data: DataProto): token_level_scores = self._expand_to_token_level(data, scores) # Note that this is only the scores, may not be the final rewards used to train RL - output = DataProto.from_dict(tensors={'rm_scores': token_level_scores}) + output = DataProto.from_dict(tensors={"rm_scores": token_level_scores}) output = 
self.ulysses_sharding_manager.postprocess_data(data=output) # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module self.reward_module._handle.reshard(True) - output = output.to('cpu') + output = output.to("cpu") return output diff --git a/recipe/sppo/sppo_worker.py b/recipe/sppo/sppo_worker.py index 07bc1115237..d2791f94079 100644 --- a/recipe/sppo/sppo_worker.py +++ b/recipe/sppo/sppo_worker.py @@ -117,5 +117,5 @@ def init_model(self): optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.actor.checkpoint.contents, + checkpoint_contents=self.config.actor.checkpoint, ) diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/e2e/ppo_trainer/run_function_reward.sh index 6d07870744f..512f3455bb9 100644 --- a/tests/e2e/ppo_trainer/run_function_reward.sh +++ b/tests/e2e/ppo_trainer/run_function_reward.sh @@ -99,7 +99,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \ actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \ actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \ - actor_rollout_ref.actor.checkpoint.contents=${CHECKPOINT_CONTENTS} \ + actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \ actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ diff --git a/tests/e2e/run_ppo_trainer_megatron.sh b/tests/e2e/run_ppo_trainer_megatron.sh index d85f68623c9..54ce1ed7758 100644 --- a/tests/e2e/run_ppo_trainer_megatron.sh +++ b/tests/e2e/run_ppo_trainer_megatron.sh @@ -2,6 +2,8 @@ set -xeuo pipefail export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping +export VERL_LOGGING_LEVEL=INFO +export VERL_PPO_LOGGING_LEVEL=INFO NUM_GPUS=${NUM_GPUS:-8} @@ -155,7 +157,7 @@ for ENGINE in "${ENGINES[@]}"; do actor_rollout_ref.actor.use_kl_loss=True \ actor_rollout_ref.actor.kl_loss_coef=0.001 \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.checkpoint.contents=$CHECKPOINT_CONTENTS \ + actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_CONTENTS \ actor_rollout_ref.rollout.name="${ENGINE}" \ actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ @@ -188,7 +190,7 @@ for ENGINE in "${ENGINES[@]}"; do critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \ critic.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ critic.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ - critic.checkpoint.contents=$CHECKPOINT_CONTENTS \ + critic.checkpoint.save_contents=$CHECKPOINT_CONTENTS \ reward_model.enable=True \ reward_model.model.path="${MODEL_PATH}" \ reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py index 912de7822f1..ce6b5640e3f 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py @@ -55,7 +55,8 @@ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as 
torchDDP - from verl.utils.megatron_utils import print_rank_0, unwrap_model + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model start_time = time.time() diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py index 18a1cf9ce90..0ad5b992a19 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py @@ -55,7 +55,8 @@ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP - from verl.utils.megatron_utils import print_rank_0, unwrap_model + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model start_time = time.time() diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py index f5ffb4a8fd1..798597db064 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py @@ -21,7 +21,8 @@ from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP -from verl.utils.megatron_utils import print_rank_0, unwrap_model +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): diff --git a/verl/models/mcore/loader.py b/verl/models/mcore/loader.py index 2c878466140..00532b4310f 100644 --- a/verl/models/mcore/loader.py +++ b/verl/models/mcore/loader.py @@ -56,7 +56,8 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP - from verl.utils.megatron_utils import print_rank_0, unwrap_model + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model start_time = time.time() @@ -382,7 +383,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None, f"{layer_name}.input_layernorm.weight", ) - + if f"{layer_name}.self_attn.q_norm.weight" in state_dict: _broadcast_tensor( sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None, diff --git a/verl/models/mcore/saver.py b/verl/models/mcore/saver.py index 5d1037681f4..a7367e952ba 100644 --- a/verl/models/mcore/saver.py +++ b/verl/models/mcore/saver.py @@ -22,7 +22,8 @@ from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP -from verl.utils.megatron_utils import print_rank_0, unwrap_model +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0): diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py index 7d15a28bb19..ffef31aaea2 100644 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py @@ 
-53,7 +53,8 @@ def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP - from verl.utils.megatron_utils import print_rank_0, unwrap_model + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model start_time = time.time() diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py index 8f581176ce3..5a5d511f54d 100644 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py @@ -53,7 +53,8 @@ def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP - from verl.utils.megatron_utils import print_rank_0, unwrap_model + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model start_time = time.time() diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py index 11cba17b11b..bdcec847922 100644 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py @@ -21,7 +21,8 @@ from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP -from verl.utils.megatron_utils import print_rank_0, unwrap_model +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 659243f0450..08e36679c1e 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -101,7 +101,11 @@ actor_rollout_ref: save_path: null # the path to save the profile result load_weight: True checkpoint: - contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + # For more flexibility, you can specify the contents to load from the checkpoint. 
+ load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents} ref: strategy: megatron use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} @@ -255,7 +259,10 @@ critic: kl_coef: 0.001 loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} checkpoint: - contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + load_contents: ${critic.checkpoint.save_contents} reward_model: enable: False diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 39d63b7642e..f53375937d9 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -217,7 +217,10 @@ actor_rollout_ref: # What to include in saved checkpoints # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - contents: ['model', 'optimizer', 'extra'] + save_contents: ['model', 'optimizer', 'extra'] + + # For more flexibility, you can specify the contents to load from the checkpoint. + load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents} # optimizer configs optim: @@ -618,7 +621,8 @@ critic: # What to include in saved checkpoints # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - contents: ['model', 'optimizer', 'extra'] + save_contents: ['model', 'optimizer', 'extra'] + load_contents: ${critic.checkpoint.save_contents} # configs for the reward model reward_model: diff --git a/verl/utils/checkpoint/checkpoint_manager.py b/verl/utils/checkpoint/checkpoint_manager.py index 31a2bacf924..b7a7d97c459 100644 --- a/verl/utils/checkpoint/checkpoint_manager.py +++ b/verl/utils/checkpoint/checkpoint_manager.py @@ -11,16 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import os import random import shutil import tempfile -from typing import Optional, Union +from typing import Union import numpy as np import torch import torch.distributed from filelock import FileLock +from omegaconf import DictConfig from transformers import PreTrainedTokenizer, ProcessorMixin from verl.utils.device import is_cuda_available, is_npu_available @@ -47,10 +49,14 @@ def __init__( optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None, processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None, - checkpoint_contents: Optional[list] = None, + checkpoint_contents: DictConfig = None, ): - if checkpoint_contents is None: - checkpoint_contents = ["model", "optimizer", "extra"] + checkpoint_load_contents = checkpoint_contents.get("load_contents", None) if checkpoint_contents else None + checkpoint_save_contents = checkpoint_contents.get("save_contents", None) if checkpoint_contents else None + if checkpoint_load_contents is None: + checkpoint_load_contents = ["model", "optimizer", "extra"] + if checkpoint_save_contents is None: + checkpoint_save_contents = ["model", "optimizer", "extra"] self.previous_global_step = None self.previous_saved_paths = [] @@ -58,11 +64,61 @@ def __init__( self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.processing_class = processing_class - self.checkpoint_contents = checkpoint_contents + self.checkpoint_load_contents = checkpoint_load_contents + self.checkpoint_save_contents = checkpoint_save_contents self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() + @property + def should_save_model(self) -> bool: + """ + Returns True if 'model' is in checkpoint_save_contents, indicating the model state should be saved. + """ + return "model" in self.checkpoint_save_contents + + @property + def should_save_optimizer(self) -> bool: + """ + Returns True if 'optimizer' is in checkpoint_save_contents, indicating the optimizer state should be saved. + """ + return "optimizer" in self.checkpoint_save_contents + + @property + def should_save_extra(self) -> bool: + """ + Returns True if 'extra' is in checkpoint_save_contents, indicating the extra state should be saved. + """ + return "extra" in self.checkpoint_save_contents + + @property + def should_save_hf_model(self) -> bool: + """ + Returns True if 'hf_model' is in checkpoint_save_contents, indicating the model should be converted to hf model and saved. + """ + return "hf_model" in self.checkpoint_save_contents + + @property + def should_load_model(self) -> bool: + """ + Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded. + """ + return "model" in self.checkpoint_load_contents + + @property + def should_load_optimizer(self) -> bool: + """ + Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded. + """ + return "optimizer" in self.checkpoint_load_contents + + @property + def should_load_extra(self) -> bool: + """ + Returns True if 'extra' is in checkpoint_load_contents, indicating the extra state should be loaded. 
+ """ + return "extra" in self.checkpoint_load_contents + def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False): raise NotImplementedError diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index f5980129e91..8d59c5d5a3b 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import warnings from typing import Optional, Union @@ -19,6 +20,7 @@ import torch import torch.distributed from accelerate import init_empty_weights +from omegaconf import DictConfig from torch.distributed.fsdp import FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin @@ -26,9 +28,14 @@ from verl.utils.device import is_cuda_available from verl.utils.fs import copy_to_local, is_non_local from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx +from verl.utils.logger import log_with_rank from .checkpoint_manager import BaseCheckpointManager +# Setup logging +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + class FSDPCheckpointManager(BaseCheckpointManager): """ @@ -44,26 +51,24 @@ class FSDPCheckpointManager(BaseCheckpointManager): lr_scheduler (LRScheduler): Learning-rate scheduler. processing_class (PreTrainedTokenizer or ProcessorMixin, optional): Pre-/post-processing artifact handler. - checkpoint_contents (list[str], optional): - Components to include; must contain 'model', 'optimizer', 'extra'. + checkpoint_contents DictConfig: Configuration for checkpoint contents. + - 'load': Components to load; must contain 'model'. Defaults to ['model', 'optimizer', 'extra']. + - 'save': Components to save; must contain 'model'. Defaults to ['model', 'optimizer', 'extra']. """ def __init__( self, model: FSDP, - optimizer: torch.optim.Optimizer, - lr_scheduler: torch.optim.lr_scheduler.LRScheduler, + optimizer: Optional[torch.optim.Optimizer] = None, + lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None, - checkpoint_contents: Optional[list] = None, + checkpoint_contents: DictConfig = None, **kwargs, ): - if checkpoint_contents is None: - checkpoint_contents = ["model", "optimizer", "extra"] if processing_class is None: assert "tokenizer" in kwargs, "tokenizer or processor must be provided" warnings.warn("`tokenizer` is deprecated. 
use `processing_class` instead.", DeprecationWarning, stacklevel=2) processing_class = kwargs.pop("tokenizer") - assert "model" in checkpoint_contents and "optimizer" in checkpoint_contents and "extra" in checkpoint_contents, f"FSDPCheckpointManager must include ['model', 'optimizer', 'extra'], got {checkpoint_contents}" super().__init__( model, @@ -89,42 +94,55 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte if local_path is None: return + # check if the checkpoint_load_contents is valid + if self.should_load_model: + assert self.model is not None, "model must be provided when checkpoint_contents.load includes ['model']" + if self.should_load_optimizer: + assert self.optimizer is not None, "optimizer must be provided when checkpoint_contents.load includes ['optimizer']" + # every rank download its own checkpoint - remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") - remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") - remote_extra_state_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") - print(f"[rank-{self.rank}]: Loading from {remote_model_path} and {remote_optim_path} and {remote_extra_state_path}") - local_model_path = copy_to_local(remote_model_path) - local_optim_path = copy_to_local(remote_optim_path) - local_extra_state_path = copy_to_local(remote_extra_state_path) - - model_state_dict = torch.load(local_model_path, weights_only=False) - optimizer_state_dict = torch.load(local_optim_path, weights_only=False) - extra_state_dict = torch.load(local_extra_state_path, weights_only=False) - - if del_local_after_load: + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) if self.should_load_model else None + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) if self.should_load_optimizer else None + with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): + if self.should_load_model: + remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") + local_model_path = copy_to_local(remote_model_path) + model_state_dict = torch.load(local_model_path, weights_only=False) + self.model.load_state_dict(model_state_dict) + log_with_rank(f"Loaded model from {remote_model_path}", rank=self.rank, logger=logger) + + if self.should_load_optimizer: + remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") + local_optim_path = copy_to_local(remote_optim_path) + optimizer_state_dict = torch.load(local_optim_path, weights_only=False) + self.optimizer.load_state_dict(optimizer_state_dict) + log_with_rank(f"Loaded optimizer from {remote_optim_path}", rank=self.rank, logger=logger) + + if self.should_load_extra: + remote_extra_state_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") + local_extra_state_path = copy_to_local(remote_extra_state_path) + extra_state_dict = torch.load(local_extra_state_path, weights_only=False) + # recover random state + if "rng" in extra_state_dict: + # 'rng' may not exist for backward compatibility + self.load_rng_state(extra_state_dict["rng"]) + log_with_rank(f"Loaded rng from {remote_extra_state_path}", rank=self.rank, logger=logger) + + lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] + if lr_scheduler_state_dict is not 
None and self.lr_scheduler is not None: + self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) + log_with_rank(f"Loaded lr_scheduler from {remote_extra_state_path}", rank=self.rank, logger=logger) + + if self.rank == 0 and del_local_after_load: try: os.remove(local_model_path) if is_non_local(local_model_path) else None os.remove(local_optim_path) if is_non_local(local_optim_path) else None os.remove(local_extra_state_path) if is_non_local(local_extra_state_path) else None except Exception as e: - print(f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored") - - lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] - - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) - with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): - self.model.load_state_dict(model_state_dict) - if self.optimizer is not None: - self.optimizer.load_state_dict(optimizer_state_dict) - # recover random state - if "rng" in extra_state_dict: - # 'rng' may not exist for backward compatibility - self.load_rng_state(extra_state_dict["rng"]) + log_with_rank(f"remove local resume ckpt file after loading failed, exception {e} will be ignored", rank=self.rank, logger=logger) - if self.lr_scheduler is not None: - self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) + # wait for everyone to load checkpoints + torch.distributed.barrier() def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): """ @@ -150,8 +168,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i # record the previous global step self.previous_global_step = global_step - # remove previous local_path - if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep: + # remove previous local_path, only rank 0 should do this + if self.rank == 0 and max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep: keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) self.previous_saved_paths = self.previous_saved_paths[keep_start:] @@ -159,30 +177,40 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i local_path = self.local_mkdir(local_path) torch.distributed.barrier() + # check if the checkpoint_save_contents is valid + if self.should_save_model: + assert self.model is not None, "model must be provided when checkpoint_contents.save includes ['model']" + if self.should_save_optimizer: + assert self.optimizer is not None, "optimizer must be provided when checkpoint_contents.save includes ['optimizer']" + # every rank will save its own model and optim shard state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) with warnings.catch_warnings(): warnings.simplefilter("ignore") with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): - model_state_dict = self.model.state_dict() - optimizer_state_dict = self.optimizer.state_dict() if self.optimizer is not None else None - lr_scheduler_state_dict = 
self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None - - extra_state_dict = { - "lr_scheduler": lr_scheduler_state_dict, - "rng": self.get_rng_state(), - } model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") - print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}") - print(f"[rank-{self.rank}]: Saving optim to {os.path.abspath(optim_path)}") - print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}") - torch.save(model_state_dict, model_path) - torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None - torch.save(extra_state_dict, extra_path) + if self.should_save_model: + model_state_dict = self.model.state_dict() + torch.save(model_state_dict, model_path) + log_with_rank(f"Saved model to {os.path.abspath(model_path)}", rank=self.rank, logger=logger) + + if self.should_save_optimizer: + optimizer_state_dict = self.optimizer.state_dict() + torch.save(optimizer_state_dict, optim_path) + log_with_rank(f"Saved optim to {os.path.abspath(optim_path)}", rank=self.rank, logger=logger) + + if self.should_save_extra: + lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None + extra_state_dict = { + "lr_scheduler": lr_scheduler_state_dict, + "rng": self.get_rng_state(), + } + torch.save(extra_state_dict, extra_path) + log_with_rank(f"Saved extra_state to {os.path.abspath(extra_path)}", rank=self.rank, logger=logger) if self.rank == 0: if fsdp_version(self.model) == 1: @@ -201,11 +229,12 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i model_config.save_pretrained(local_path) self.processing_class.save_pretrained(local_path) + log_with_rank(f"Saved model config and tokenizer class to {os.path.abspath(local_path)}", rank=self.rank, logger=logger, log_only_rank_0=True) # wait for everyone to dump to local torch.distributed.barrier() - if "hf_model" in self.checkpoint_contents: + if self.should_save_hf_model: hf_local_path = os.path.join(local_path, "huggingface") os.makedirs(hf_local_path, exist_ok=True) @@ -243,6 +272,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i save_model.save_pretrained(hf_local_path, state_dict=state_dict) self.processing_class.save_pretrained(hf_local_path) + log_with_rank(f"Saved hf_model to {os.path.abspath(hf_local_path)}", rank=self.rank, logger=logger, log_only_rank_0=True) del state_dict del save_model diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py index cd6b0f8cd31..6125b49261c 100644 --- a/verl/utils/checkpoint/megatron_checkpoint_manager.py +++ b/verl/utils/checkpoint/megatron_checkpoint_manager.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging
 import os
 import random
-from typing import Optional
 
 import numpy as np
 import torch
 import torch.distributed
 from megatron.core import mpu, tensor_parallel
 from megatron.core.dist_checkpointing.mapping import ShardedObject
+from omegaconf import DictConfig
 from transformers import GenerationConfig
 
 from verl.models.weight_loader_registry import get_weight_saver
 from verl.utils.fs import is_non_local
+from verl.utils.logger import log_with_rank
 from verl.utils.megatron_utils import (
     get_hf_config_and_tokenizer_checkpoint_path,
     get_hf_model_checkpoint_path,
@@ -36,6 +38,10 @@
 
 from .checkpoint_manager import BaseCheckpointManager
 
+# Setup logging
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))
+
 
 class MegatronCheckpointManager(BaseCheckpointManager):
     """
@@ -67,11 +73,9 @@ def __init__(
         optimizer_scheduler,
         use_distributed_optimizer: bool,
         use_checkpoint_opt_param_scheduler: bool = False,
-        checkpoint_contents: Optional[list] = None,
+        checkpoint_contents: DictConfig = None,
         **kwargs,
     ):
-        if checkpoint_contents is None:
-            checkpoint_contents = ["model", "optimizer", "extra"]
         super().__init__(
             model,
             optimizer=optimizer,
@@ -180,12 +184,12 @@ def get_checkpoint_name(
 
     def load_optimizer(self, ckpt_path):
         # TODO: Check Optimizer format and distributed optimizer
         optimizer_path = get_optimizer_checkpoint_path(ckpt_path)
-        print(f"Loading optimizer from {optimizer_path}")
+        log_with_rank(f"Loading optimizer from {optimizer_path}", rank=self.rank, logger=logger)
         self.optimizer.load_parameter_state(optimizer_path)
 
     def load_rng_states(self, ckpt_path, data_parallel_random_init=False, use_dist_ckpt=False):
         rng_state_path = get_rng_states_checkpoint_path(ckpt_path, only_rank0_save=False)
-        print(f"Loading rng states from {rng_state_path}")
+        log_with_rank(f"Loading rng states from {rng_state_path}", rank=self.rank, logger=logger)
         rng_state = torch.load(rng_state_path, weights_only=False)
         # access rng_state for data parallel rank
         if not use_dist_ckpt:
@@ -203,35 +207,35 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
         if local_path is None:
             return
 
-        if "model" in self.checkpoint_contents:
+        if self.should_load_model:
             model_path = get_model_checkpoint_path(local_path)
             ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False)
             state_dicts = torch.load(os.path.join(ckpt_name), weights_only=False)
 
             assert len(state_dicts) == len(self.model), f"state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}"
 
             for vpp_rank, (state_dict, model) in enumerate(zip(state_dicts, self.model)):
                 model.load_state_dict(state_dict)
-            print(f"Loaded sharded model checkpoint from {model_path}")
+            log_with_rank(f"Loaded sharded model checkpoint from {model_path}", rank=self.rank, logger=logger)
 
-        if "optimizer" in self.checkpoint_contents:
+        if self.should_load_optimizer:
             self.load_optimizer(local_path)
 
-        if "extra" in self.checkpoint_contents:
+        if self.should_load_extra:
             self.load_rng_states(local_path)
 
             if self.use_checkpoint_opt_param_scheduler:
-                optimizer_scheduler_path = get_optimizer_scheduler_checkpoint_path(local_path, only_rank0_save=False)
+                optimizer_scheduler_path = get_optimizer_scheduler_checkpoint_path(local_path, only_rank0_save=True)
                 if os.path.exists(optimizer_scheduler_path):
-                    print(f"Loading optimizer scheduler from {optimizer_scheduler_path}")
+                    log_with_rank(f"Loading optimizer scheduler from {optimizer_scheduler_path}", rank=self.rank, logger=logger)
                     state_dict = torch.load(optimizer_scheduler_path, weights_only=False)
                     if self.lr_scheduler is not None:
                         self.lr_scheduler.load_state_dict(state_dict)
                 else:
-                    print(f"Optimizer scheduler path {optimizer_scheduler_path} does not exist, skipping loading.")
+                    log_with_rank(f"Optimizer scheduler path {optimizer_scheduler_path} does not exist, skipping loading.", rank=self.rank, logger=logger)
 
         if del_local_after_load:
             try:
                 os.remove(local_path) if is_non_local(local_path) else None
             except Exception as e:
-                print(f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored")
+                log_with_rank(f"remove local resume ckpt file after loading failed, exception {e} will be ignored", rank=self.rank, logger=logger)
 
     def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):
         # record the previous global step
@@ -246,20 +250,20 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
         local_path = self.local_mkdir(local_path)
 
         # Save Model
-        if "model" in self.checkpoint_contents and mpu.get_data_parallel_rank() == 0:
+        if self.should_save_model and mpu.get_data_parallel_rank() == 0:
             state_dicts = []
 
             for vpp_rank, model in enumerate(self.model):
                 state_dict = model.state_dict()
                 state_dicts.append(state_dict)
 
-            print(f"Saving sharded model checkpoint to {local_path}")
+            log_with_rank(f"Saving sharded model checkpoint to {local_path}", rank=self.rank, logger=logger)
             model_ckpt_path = get_model_checkpoint_path(local_path)
             hf_config_and_tokenizer_path = get_hf_config_and_tokenizer_checkpoint_path(local_path)
             ckpt_name = self.get_checkpoint_name(model_ckpt_path, return_base_dir=False)
             torch.save(state_dicts, os.path.join(ckpt_name))
-            print(f"Saved checkpoint to {model_ckpt_path}")
+            log_with_rank(f"Saved checkpoint to {model_ckpt_path}", rank=self.rank, logger=logger)
             if self.rank == 0:
                 self.processing_class.save_pretrained(hf_config_and_tokenizer_path)
                 self.hf_config.save_pretrained(hf_config_and_tokenizer_path)
@@ -271,14 +275,14 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
                     # if the generation config isn't available, we don't save it
                     pass
                 if hdfs_path is not None:
-                    print(f"Uploading checkpoint to {hdfs_path}")
+                    log_with_rank(f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger)
                     from verl.utils import hdfs_io
 
                     hdfs_io.makedirs(hdfs_path, exist_ok=True)
                     hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)
                     hdfs_io.copy(src=hf_config_and_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True)
 
-        if "hf_model" in self.checkpoint_contents:
+        if self.should_save_hf_model:
             # wait for everyone to dump to local
             state_dict = self.weight_saver(
                 self.model,
@@ -290,9 +294,6 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
             torch.distributed.barrier()
 
             if self.rank == 0:
-                print(f"self.param_dtype: {self.param_dtype}")
-                for key in state_dict.keys():
-                    print(f"state_dict[key].dtype: {key} {state_dict[key].dtype}")
                 hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
                 import warnings
@@ -311,37 +312,38 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
                     model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto")
                 model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
                 self.processing_class.save_pretrained(hf_model_ckpt_path)
+                log_with_rank(f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", rank=self.rank, logger=logger, log_only_rank_0=True)
 
                 if hdfs_path is not None:
-                    print(f"Uploading checkpoint to {hdfs_path}")
+                    log_with_rank(f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True)
                    from verl.utils import hdfs_io
 
                     hdfs_io.makedirs(hdfs_path, exist_ok=True)
                     hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)
+                    log_with_rank(f"HDFS checkpoint uploaded to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True)
 
         # Save Optimizer
-        if "optimizer" in self.checkpoint_contents:
+        if self.should_save_optimizer:
             torch.distributed.barrier()
 
             optimizer_path = get_optimizer_checkpoint_path(local_path)
             self.optimizer.save_parameter_state(optimizer_path)
-            if self.rank == 0:
-                print(f"saving optimizer state to {optimizer_path}")
+            log_with_rank(f"Saved optimizer state to {optimizer_path}", rank=self.rank, logger=logger)
 
         # Save RNG States
-        if "extra" in self.checkpoint_contents:
+        if self.should_save_extra:
             torch.distributed.barrier()
 
             rng_state_path = get_rng_states_checkpoint_path(local_path, only_rank0_save=False)
             rng_state = self.get_rng_state()
             torch.save(rng_state, rng_state_path)
-            print(f"Rank {self.rank} saving rng states to {rng_state_path}")
-
-            optimizer_scheduler_path = get_optimizer_scheduler_checkpoint_path(local_path, only_rank0_save=False)
-            if self.lr_scheduler is not None:
-                state_dict = self.lr_scheduler.state_dict()
-                torch.save(state_dict, optimizer_scheduler_path)
-                if self.rank == 0:
-                    print(f"Rank {self.rank} saving optimizer scheduler state to {optimizer_scheduler_path}")
+            log_with_rank(f"Saved rng states to {rng_state_path}", rank=self.rank, logger=logger)
+
+            if self.rank == 0:
+                optimizer_scheduler_path = get_optimizer_scheduler_checkpoint_path(local_path, only_rank0_save=True)
+                if self.lr_scheduler is not None:
+                    state_dict = self.lr_scheduler.state_dict()
+                    torch.save(state_dict, optimizer_scheduler_path)
+                    log_with_rank(f"Saved optimizer scheduler state to {optimizer_scheduler_path}", rank=self.rank, logger=logger, log_only_rank_0=True)
 
         self.previous_saved_paths.append(local_path)
diff --git a/verl/utils/debug/performance.py b/verl/utils/debug/performance.py
index c6e7ec3c114..169746b8e77 100644
--- a/verl/utils/debug/performance.py
+++ b/verl/utils/debug/performance.py
@@ -23,7 +23,7 @@
 from codetiming import Timer
 
 from verl.utils.device import get_torch_device
-from verl.utils.logger.aggregate_logger import DecoratorLoggerBase
+from verl.utils.logger import DecoratorLoggerBase
 
 
 def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str]:
diff --git a/verl/utils/logger/__init__.py b/verl/utils/logger/__init__.py
index 1ce90c5eb35..4c7e1904216 100644
--- a/verl/utils/logger/__init__.py
+++ b/verl/utils/logger/__init__.py
@@ -11,3 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+
+from .aggregate_logger import DecoratorLoggerBase, LocalLogger, log_with_rank, print_rank_0, print_with_rank, print_with_rank_and_timer
+
+__all__ = ["LocalLogger", "DecoratorLoggerBase", "print_rank_0", "print_with_rank", "print_with_rank_and_timer", "log_with_rank"]
diff --git a/verl/utils/logger/aggregate_logger.py b/verl/utils/logger/aggregate_logger.py
index 47d9945cdcc..6bbb96545cd 100644
--- a/verl/utils/logger/aggregate_logger.py
+++ b/verl/utils/logger/aggregate_logger.py
@@ -15,10 +15,13 @@
 A Ray logger will receive logging info from different processes.
""" +import datetime import logging import numbers from typing import Dict +import torch + def concat_dict_to_str(dict: Dict, step): output = [f"step:{step}"] @@ -63,3 +66,57 @@ def log_by_logging(self, log_str): raise ValueError("Logger is not initialized") if not self.log_only_rank_0 or self.rank == 0: self.logger.log(self.level, f"{self.role} {log_str}") + + +def print_rank_0(message): + """If distributed is initialized, print only on rank 0.""" + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(message, flush=True) + else: + print(message, flush=True) + + +def print_with_rank(message: str, rank: int = 0, log_only_rank_0: bool = False): + """_summary_ + Print a message with rank information. + This function prints the message only if `log_only_rank_0` is False or if the rank is 0. + + Args: + message (str): _description_ + rank (int, optional): _description_. Defaults to 0. + log_only_rank_0 (bool, optional): _description_. Defaults to False. + """ + if not log_only_rank_0 or rank == 0: + print(f"[Rank {rank}] {message}", flush=True) + + +def print_with_rank_and_timer(message: str, rank: int = 0, log_only_rank_0: bool = False): + """_summary_ + Print a message with rank information and a timestamp. + This function prints the message only if `log_only_rank_0` is False or if the rank is 0. + + Args: + message (str): _description_ + rank (int, optional): _description_. Defaults to 0. + log_only_rank_0 (bool, optional): _description_. Defaults to False. + """ + now = datetime.datetime.now() + message = f"[{now.strftime('%Y-%m-%d %H:%M:%S')}] [Rank {rank}] {message}" + if not log_only_rank_0 or rank == 0: + print(message, flush=True) + + +def log_with_rank(message: str, rank, logger: logging.Logger, level=logging.INFO, log_only_rank_0: bool = False): + """_summary_ + Log a message with rank information using a logger. + This function logs the message only if `log_only_rank_0` is False or if the rank is 0. + Args: + message (str): The message to log. + rank (int): The rank of the process. + logger (logging.Logger): The logger instance to use for logging. + level (int, optional): The logging level. Defaults to logging.INFO. + log_only_rank_0 (bool, optional): If True, only log for rank 0. Defaults to False. 
+ """ + if not log_only_rank_0 or rank == 0: + logger.log(level, f"[Rank {rank}] {message}") diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index 3b6b42e69df..69a9b7bcac4 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -420,15 +420,6 @@ def _iter_opts(opt): torch.cuda.empty_cache() -def print_rank_0(message): - """If distributed is initialized, print only on rank 0.""" - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - print(message, flush=True) - else: - print(message, flush=True) - - def get_model_checkpoint_path(checkpoint_path): os.makedirs(checkpoint_path, exist_ok=True) return os.path.join(checkpoint_path, "model") diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index fde51a5dbaf..0eafa619f58 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -120,7 +120,7 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li self.logger["tensorboard"] = _TensorboardAdapter() if "console" in default_backend: - from verl.utils.logger.aggregate_logger import LocalLogger + from verl.utils.logger import LocalLogger self.console_logger = LocalLogger(print_to_console=True) self.logger["console"] = self.console_logger diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index f31a7dd0329..1bb30a41eff 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -530,7 +530,7 @@ def init_model(self): if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) - # load from checkpoint + if self._is_actor: OmegaConf.set_struct(self.config.actor, True) with open_dict(self.config.actor): @@ -567,7 +567,20 @@ def init_model(self): optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.actor.checkpoint.contents, + checkpoint_contents=self.config.actor.checkpoint, + ) + + if not self._is_actor and self._is_rollout: + # If ActorRolloutRefWorker is initialized as a standalone rollout, + # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout. 
+
+            checkpoint_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []})
+            self.checkpoint_manager = FSDPCheckpointManager(
+                model=self.actor_module_fsdp,
+                optimizer=None,
+                lr_scheduler=None,
+                processing_class=self.processor if self.processor is not None else self.tokenizer,
+                checkpoint_contents=checkpoint_contents,
             )
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -728,6 +741,8 @@ def compute_ref_log_prob(self, data: DataProto):
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+        from verl.utils.logger import log_with_rank
+
         # only support save and load ckpt for actor
         assert self._is_actor
@@ -756,18 +771,18 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to
                     with open(os.path.join(lora_save_path, "adapter_config.json"), "w", encoding="utf-8") as f:
                         json.dump(peft_config, f, ensure_ascii=False, indent=4)
             except Exception as e:
-                if dist.get_rank() == 0:
-                    print(f"[rank-{self.rank}]: Save LoRA Adapter Error ({e})")
+                log_with_rank(f"Save LoRA Adapter Error ({e})", rank=dist.get_rank(), logger=logger, log_only_rank_0=True)
 
             dist.barrier()
-            if dist.get_rank() == 0:
-                print(f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}")
+            log_with_rank(f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}", rank=dist.get_rank(), logger=logger, log_only_rank_0=True)
 
         if self._is_offload_param:
             offload_fsdp_model_to_cpu(self.actor_module_fsdp)
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
+        assert self._is_actor or (not self._is_actor and self._is_rollout), f"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got {self._is_actor} and {self._is_rollout}"
+
         if self._is_offload_param:
             load_fsdp_model_to_gpu(self.actor_module_fsdp)
@@ -854,7 +869,7 @@ def _build_critic_model_optimizer(self, config):
             torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
             torch_dtype = PrecisionType.to_dtype(torch_dtype)
 
-        from transformers import AutoConfig, AutoModelForTokenClassification
+        from transformers import AutoConfig
 
         critic_model_config = AutoConfig.from_pretrained(local_path, attn_implementation="flash_attention_2", trust_remote_code=config.model.get("trust_remote_code", False))
         critic_model_config.num_labels = 1
@@ -1023,7 +1038,7 @@ def init_model(self):
             optimizer=self.critic_optimizer,
             lr_scheduler=self.critic_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer,
-            checkpoint_contents=self.config.checkpoint.contents,
+            checkpoint_contents=self.config.checkpoint,
         )
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index 33187c475e8..46460e04456 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -406,7 +406,7 @@ def init_model(self):
             optimizer_scheduler=self.actor_optimizer_scheduler,
             use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
             use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler,
-            checkpoint_contents=self.config.actor.checkpoint.contents,
+            checkpoint_contents=self.config.actor.checkpoint,
         )
         torch.cuda.empty_cache()
         log_gpu_memory_usage("After init_model finish", logger=logger)
@@ -750,7 +750,7 @@ def init_model(self):
             optimizer_scheduler=self.critic_optimizer_scheduler,
             use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,
             use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler,
-            checkpoint_contents=self.config.checkpoint.contents,
+            checkpoint_contents=self.config.checkpoint,
         )
 
     @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
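The hunks above replace ad-hoc `print` calls with the `log_with_rank` helper and pass the checkpoint section of the config (a `DictConfig` with `save_contents`/`load_contents`) to the checkpoint managers. The snippet below is a minimal, self-contained sketch of how those two pieces are meant to be used together; it is illustrative only and not part of the patch, and the logger name, config values, and hard-coded `rank` are assumptions for the example.

    # Illustrative sketch, not part of the patch.
    import logging
    import os

    from omegaconf import OmegaConf

    from verl.utils.logger import log_with_rank

    # Attach a handler so INFO-level messages are actually emitted in a bare script.
    logging.basicConfig(level=logging.INFO)

    # Module-level logger, mirroring the pattern used in the patched checkpoint managers.
    logger = logging.getLogger(__file__)
    logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))

    # A standalone-rollout style checkpoint config: load only the model, save nothing.
    checkpoint_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []})

    rank = 0  # assumption for the example; in a worker this comes from torch.distributed.get_rank()
    log_with_rank(f"checkpoint contents: {OmegaConf.to_container(checkpoint_contents)}", rank=rank, logger=logger)
    # With log_only_rank_0=True, only rank 0 emits the line.
    log_with_rank("uploading checkpoint", rank=rank, logger=logger, log_only_rank_0=True)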