2 changes: 1 addition & 1 deletion examples/split_placement/config/ppo_trainer_split.yaml
@@ -119,7 +119,7 @@ critic:
ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: 1 # sp size
-ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
+epochs: 1
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
cliprange_value: 0.5
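Note on the config change above: the critic's epoch count now lives under `critic.epochs` (defaulting to 1) instead of `critic.ppo_epochs` interpolated from the actor. A minimal sketch (not part of this PR) of how the renamed key resolves, assuming OmegaConf and a trimmed-down config that mirrors the YAML above:

```python
# Minimal sketch, assuming OmegaConf and a cut-down config mirroring the YAML above.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    actor_rollout_ref:
      actor:
        ppo_epochs: 1
        shuffle: false
    critic:
      epochs: 1                      # renamed from ppo_epochs, no longer tied to the actor
      shuffle: ${actor_rollout_ref.actor.shuffle}
      grad_clip: 1.0
      cliprange_value: 0.5
    """
)

print(cfg.critic.epochs)  # -> 1
# Launch scripts that still override `critic.ppo_epochs=...` will no longer
# affect the critic's epoch count and need to switch to `critic.epochs=...`.
```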
2 changes: 1 addition & 1 deletion verl/trainer/config/ppo_megatron_trainer.yaml
@@ -192,7 +192,7 @@ critic:
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
-ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
+epochs: 1
data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed}
shuffle: ${actor_rollout_ref.actor.shuffle}
cliprange_value: 0.5
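Since the same rename appears in every trainer config, older launch scripts may still set the deprecated key. A hypothetical compatibility shim (not something this PR adds) could translate it during config post-processing:

```python
# Hypothetical shim, not part of this PR: copy a leftover critic.ppo_epochs
# override into the new critic.epochs field and drop the old key.
import warnings
from omegaconf import OmegaConf, open_dict

def migrate_critic_epochs(cfg):
    if OmegaConf.select(cfg, "critic.ppo_epochs") is not None:
        warnings.warn("critic.ppo_epochs is deprecated; use critic.epochs instead")
        with open_dict(cfg):
            cfg.critic.epochs = cfg.critic.ppo_epochs
            del cfg.critic["ppo_epochs"]
    return cfg
```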
2 changes: 1 addition & 1 deletion verl/trainer/config/ppo_trainer.yaml
@@ -160,7 +160,7 @@ critic:
ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: 1 # sp size
-ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
+epochs: 1
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
cliprange_value: 0.5
2 changes: 1 addition & 1 deletion verl/workers/critic/dp_critic.py
@@ -186,7 +186,7 @@ def update_critic(self, data: DataProto):
else:
dataloader = batch.split(self.config.ppo_mini_batch_size)

-for epoch in range(self.config.ppo_epochs):
+for epoch in range(self.config.epochs):
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
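For readers skimming the diff, the renamed field drives the outer loop of `update_critic`: the critic makes `config.epochs` passes over its mini-batch dataloader. A stripped-down sketch of that control flow, with the loss and optimizer logic elided and a plain list standing in for the DataProto dataloader:

```python
# Simplified sketch of the update_critic control flow after the rename;
# `mini_batches` stands in for the DataProto dataloader and `train_step`
# for the micro-batch splitting, forward/backward and optimizer step.
def update_critic_sketch(config, mini_batches, train_step):
    metrics = []
    for epoch in range(config.epochs):            # previously config.ppo_epochs
        for batch_idx, mini_batch in enumerate(mini_batches):
            metrics.append(train_step(mini_batch))
    return metrics
```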
2 changes: 1 addition & 1 deletion verl/workers/critic/megatron_critic.py
@@ -123,7 +123,7 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
data = data.select(batch_keys=select_keys)
return data.make_iterator(
mini_batch_size=self.config.ppo_mini_batch_size,
-epochs=self.config.ppo_epochs,
+epochs=self.config.epochs,
seed=self.config.data_loader_seed,
dataloader_kwargs={"shuffle": self.config.shuffle},
)
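In the Megatron path the iterator itself replays the data, so `epochs` directly scales the number of critic optimizer steps per update to roughly `epochs * (batch_size // ppo_mini_batch_size)`. A hypothetical back-of-the-envelope helper (not in the PR):

```python
# Hypothetical helper, not part of the PR: rough count of mini-batch steps
# yielded by make_minibatch_iterator for a single update_critic call.
def critic_steps_per_update(batch_size: int, ppo_mini_batch_size: int, epochs: int) -> int:
    return epochs * (batch_size // ppo_mini_batch_size)

# Example: a global batch of 1024 sequences with ppo_mini_batch_size=256 and
# the new default epochs=1 gives 4 optimizer steps per critic update.
assert critic_steps_per_update(1024, 256, 1) == 4
```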
2 changes: 1 addition & 1 deletion verl/workers/fsdp_workers.py
@@ -921,7 +921,7 @@ def update_critic(self, data: DataProto):

global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
metrics["perf/mfu/critic"] = estimated_flops * self.config.epochs / promised_flops / self.world_size

self.critic_lr_scheduler.step()
lr = self.critic_lr_scheduler.get_last_lr()[0]
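The metric above folds the epoch count into the critic's MFU estimate, since each epoch reprocesses the same global tokens. Written out as a standalone function (a sketch only; the internals of `flops_counter.estimate_flops` are not reproduced here):

```python
# Sketch of the critic MFU metric as computed above; argument names mirror the diff.
def critic_mfu(estimated_flops: float, promised_flops: float,
               epochs: int, world_size: int) -> float:
    # Multiplying by `epochs` accounts for the critic re-processing the same
    # tokens once per epoch; dividing by world_size spreads the achieved FLOPs
    # over all participating GPUs relative to a single device's peak.
    return estimated_flops * epochs / promised_flops / world_size
```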
2 changes: 1 addition & 1 deletion verl/workers/megatron_workers.py
@@ -653,7 +653,7 @@ def update_critic(self, data: DataProto):
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
metrics["perf/mfu/critic"] = estimated_flops * self.config.epochs / promised_flops / self.world_size
output = DataProto(batch=None, meta_info={"metrics": metrics})

if self._is_offload_param: