28 commits
fcdc0f5  initial test (tinuademargaret, Apr 8, 2025)
d463293  initial profiling (Apr 8, 2025)
4b91919  initial fusion test (tinuademargaret, Apr 8, 2025)
2d8e0bc  fixes (tinuademargaret, Apr 8, 2025)
3ff92ac  fix lr scheduler (tinuademargaret, Apr 10, 2025)
f991417  debugging fusion with gradient accumulation (tinuademargaret, Apr 11, 2025)
50c8db3  use flag (tinuademargaret, Apr 17, 2025)
7d263e2  fixes (tinuademargaret, Apr 21, 2025)
b632651  tests (tinuademargaret, Apr 23, 2025)
81930ea  test flat params (tinuademargaret, Apr 23, 2025)
74e09c7  update params to flat params (tinuademargaret, Apr 23, 2025)
281fc87  sft optimiser fuse (tinuademargaret, Apr 23, 2025)
bf2a0a7  fix batch size (tinuademargaret, Apr 23, 2025)
3a5d7e9  fix normalise bsz (tinuademargaret, Apr 23, 2025)
c6a39b7  test ppo config (tinuademargaret, Apr 23, 2025)
ac0d426  fix config (tinuademargaret, Apr 24, 2025)
20af206  add bwd hook to actor worker (tinuademargaret, Apr 24, 2025)
55c35ef  update sft trainer (tinuademargaret, Apr 24, 2025)
e7b7acc  fixes (tinuademargaret, Apr 24, 2025)
d827d11  update critic (tinuademargaret, Apr 25, 2025)
22565e1  fix (tinuademargaret, Apr 25, 2025)
f2f94ef  delete prev memory (tinuademargaret, Apr 28, 2025)
ccefe19  remove prev scheduler (tinuademargaret, Apr 28, 2025)
70a1859  Merge branch 'main' into feat-optimiser-fuse (tinuademargaret, Apr 28, 2025)
2c43fa5  revert changes for rl workers (tinuademargaret, May 7, 2025)
ccd3658  update sft config (tinuademargaret, May 7, 2025)
a098aa0  clean up (tinuademargaret, May 8, 2025)
7df87b6  Merge branch 'main' into feat-optimiser-fuse (tinuademargaret, May 8, 2025)
48 changes: 48 additions & 0 deletions examples/ppo_trainer/run_qwen2.5-0.5b.sh
@@ -0,0 +1,48 @@
set -x

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet

train_files="['$gsm8k_train_path']"
test_files="['$gsm8k_test_path']"

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=gae \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=8 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=Qwen/Qwen2.5-Coder-0.5B \
actor_rollout_ref.model.enable_gradient_checkpointing=False \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=8 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
critic.model.enable_gradient_checkpointing=False \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_example' \
trainer.experiment_name='Qwen2.5-0.5B-Instruct_critic' \
trainer.n_gpus_per_node=1 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=10 \
trainer.total_epochs=1 $@
4 changes: 3 additions & 1 deletion verl/trainer/config/sft_trainer.yaml
@@ -1,7 +1,8 @@
data:
train_batch_size: 256
val_batch_size: 4
micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
micro_batch_size_per_gpu: 4 # this is also val batch size
micro_batch_size_per_gpu: 4
train_files: ~/data/gsm8k/train.parquet
val_files: ~/data/gsm8k/test.parquet
# Single-turn settings
@@ -41,6 +42,7 @@ optim:
warmup_steps_ratio: 0.1
clip_grad: 1.0
lr_scheduler: cosine
bwd_hook: False
ulysses_sequence_parallel_size: 1
use_remove_padding: False
trainer:
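
Note: the new `optim.bwd_hook` flag switches the SFT trainer to a fused, optimizer-in-backward update, which is incompatible with gradient accumulation, so `data.micro_batch_size_per_gpu` must not be positive when it is enabled (the trainer changes below enforce this). A minimal sketch of preparing such a config; the override mechanics here are an assumption and are not part of this diff, only the option names come from the yaml above:

```python
# Sketch only: toggling the fused-optimizer flag via OmegaConf overrides.
from omegaconf import OmegaConf

cfg = OmegaConf.load("verl/trainer/config/sft_trainer.yaml")
cfg = OmegaConf.merge(
    cfg,
    OmegaConf.from_dotlist([
        "optim.bwd_hook=true",
        "data.micro_batch_size_per_gpu=0",  # fused updates cannot accumulate gradients
    ]),
)
# Mirrors the incompatibility check added in fsdp_sft_trainer.py below.
assert not (cfg.optim.bwd_hook and cfg.data.micro_batch_size_per_gpu > 0)
```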
174 changes: 112 additions & 62 deletions verl/trainer/fsdp_sft_trainer.py
@@ -30,6 +30,11 @@
import hydra
import torch
import torch.distributed
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload
from torch.distributed.optim import _apply_optimizer_in_backward
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from peft import LoraConfig, TaskType, get_peft_model
from tensordict import TensorDict
@@ -57,8 +62,9 @@
)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager


logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN"))
logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN'))


def extract_step(path):
@@ -91,15 +97,30 @@ def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceM
if self.config.data.chat_template is not None:
raise ValueError("Apply Chat template from config is not supported yet.")

self.micro_batch_size = self.config.data.micro_batch_size_per_gpu

# normalize dp size
self._normalize_config_bsz()

# Set sequence parallel size
self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1)
self.use_remove_padding = getattr(self.config, "use_remove_padding", False)
if self.device_mesh.get_rank() == 0:
print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}")
print(f"Using remove padding: {self.use_remove_padding}")
print(f'Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}')
print(f'Using remove padding: {self.use_remove_padding}')

self.optimizer = None
self.optim_bwd_hook = self.config.optim.bwd_hook
self.optim_dict = None
self.lr_scheduler = None


# Optimizer in backward is not compatible with gradient accumulation
if self.optim_bwd_hook:
if self.micro_batch_size > 0:
raise RuntimeError(
"Gradient accumulation is not compatible with optimizer in backward step"
)

self._build_dataloader(train_dataset, val_dataset)
# build model
@@ -114,11 +135,16 @@ def _normalize_config_bsz(self):
if self.device_mesh.get_rank() == 0:
print(f"Normalize batch size by dp {dp_size}")

assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}"
assert self.config.data.train_batch_size % dp_size == 0, f"Global train batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}"

assert self.config.data.val_batch_size % dp_size == 0, f"Global val batch size {self.config.data.val_batch_size} is not divisible by dp size {dp_size}"

self.config.data.train_batch_size //= dp_size

assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0
self.config.data.val_batch_size //= dp_size

if self.micro_batch_size > 0:
assert self.config.data.train_batch_size % self.micro_batch_size == 0

def _build_dataloader(self, train_dataset, val_dataset):
# build dataset
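
For reference, a worked instance of the batch-size normalisation in `_normalize_config_bsz` above, assuming the yaml defaults (train_batch_size=256, val_batch_size=4) and a data-parallel size of 4; the numbers are illustrative only:

```python
# Illustrative numbers only: per-rank batch sizes after _normalize_config_bsz.
dp_size = 4
train_batch_size, val_batch_size = 256, 4  # yaml defaults
assert train_batch_size % dp_size == 0 and val_batch_size % dp_size == 0
train_batch_size //= dp_size  # 64 samples per rank per step
val_batch_size //= dp_size    # 1 validation sample per rank per step
```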
@@ -139,27 +165,31 @@ def _build_dataloader(self, train_dataset, val_dataset):
rank = self.device_mesh.get_rank()
world_size = self.device_mesh.size()
if self.device_mesh.get_rank() == 0:
print(f"Using FSDP rank {rank} and size {world_size} for data distribution")

self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True)
self.train_dataloader = DataLoader(
dataset=self.train_dataset,
batch_size=config.data.train_batch_size,
sampler=self.train_sampler,
num_workers=8,
pin_memory=True,
drop_last=True,
)

self.val_sampler = DistributedSampler(self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True)
self.val_dataloader = DataLoader(
dataset=self.val_dataset,
batch_size=config.data.micro_batch_size_per_gpu,
sampler=self.val_sampler,
num_workers=8,
pin_memory=True,
drop_last=True,
)
print(f'Using FSDP rank {rank} and size {world_size} for data distribution')

self.train_sampler = DistributedSampler(self.train_dataset,
shuffle=True,
num_replicas=world_size,
rank=rank,
drop_last=True)
self.train_dataloader = DataLoader(dataset=self.train_dataset,
batch_size=config.data.train_batch_size,
sampler=self.train_sampler,
num_workers=8,
pin_memory=True,
drop_last=True)

self.val_sampler = DistributedSampler(self.val_dataset,
shuffle=False,
num_replicas=world_size,
rank=rank,
drop_last=True)
self.val_dataloader = DataLoader(dataset=self.val_dataset,
batch_size=config.data.val_batch_size,
sampler=self.val_sampler,
num_workers=8,
pin_memory=True,
drop_last=True)

def _build_model_optimizer(self):
# TODO (zhangchi.usc1992):
@@ -236,29 +266,31 @@ def _build_model_optimizer(self):
else:
cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)

self.fsdp_model = FSDP(
module=self.model,
auto_wrap_policy=auto_wrap_policy,
param_init_fn=init_fn,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=self.device_mesh,
sync_module_states=True,
device_id=torch.cuda.current_device(),
cpu_offload=cpu_offload,
use_orig_params=False,
)

log_gpu_memory_usage("After FSDP wrapping", logger=logger)

self.optimizer = optim.AdamW(
self.fsdp_model.parameters(),
lr=self.config.optim.lr,
betas=self.config.optim.betas,
weight_decay=self.config.optim.weight_decay,
)
self.fsdp_model = FSDP(module=self.model,
auto_wrap_policy=auto_wrap_policy,
param_init_fn=init_fn,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=self.device_mesh,
sync_module_states=True,
device_id=torch.cuda.current_device(),
cpu_offload=cpu_offload,
use_orig_params=True)

log_gpu_memory_usage('After FSDP wrapping', logger=logger)

from verl.utils.torch_functional import apply_optimizer_in_backward, get_cosine_schedule_with_warmup, update_scheduler_with_custom_step

if self.optim_bwd_hook:
optim_dict = apply_optimizer_in_backward(self.fsdp_model, self.config.optim)
self.optimizer = next(iter(optim_dict.values()))
else:
self.optimizer = optim.AdamW(self.fsdp_model.parameters(),
lr=self.config.optim.lr,
betas=self.config.optim.betas,
weight_decay=self.config.optim.weight_decay)

log_gpu_memory_usage("After initialize optimizer", logger=logger)
log_gpu_memory_usage('After initialize optimizer', logger=logger)

self.steps_per_epoch = len(self.train_dataloader)
self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs
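
The helper `apply_optimizer_in_backward` called above is not shown in this diff. A minimal sketch of what it might do, assuming it follows the standard PyTorch optimizer-in-backward recipe (one small optimizer per parameter, stepped from a post-accumulate-grad hook so each gradient is consumed and freed during backward); the earlier import of `torch.distributed.optim._apply_optimizer_in_backward` suggests the real helper may instead delegate to that utility. Switching the FSDP wrapper to `use_orig_params=True` is what makes per-parameter hooks usable here. Names and details below are assumptions, not the verl implementation:

```python
import torch
from torch import nn, optim


def apply_optimizer_in_backward_sketch(model: nn.Module, optim_config) -> dict:
    """Attach one tiny optimizer per parameter and step it from a backward hook."""
    optim_dict = {
        p: optim.AdamW(
            [p],
            lr=optim_config.lr,
            betas=tuple(optim_config.betas),
            weight_decay=optim_config.weight_decay,
            foreach=False,  # single-tensor step; each optimizer owns one parameter
        )
        for p in model.parameters()
        if p.requires_grad
    }

    def _step_hook(param: torch.Tensor) -> None:
        # Fires once this parameter's gradient is fully accumulated: step, then
        # drop the gradient immediately instead of holding it for a global step().
        optim_dict[param].step()
        optim_dict[param].zero_grad()

    for p in optim_dict:
        p.register_post_accumulate_grad_hook(_step_hook)
    return optim_dict
```

Returning the dict matches how the trainer picks `next(iter(optim_dict.values()))` as the anchor optimizer for the LR scheduler.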
@@ -275,6 +307,9 @@ def _build_model_optimizer(self):
else:
raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}")

if self.optim_bwd_hook:
update_scheduler_with_custom_step(self.lr_scheduler, optim_dict)

def _compute_loss_and_backward(self, batch, do_backward=True):
"""Compute loss with optional sequence parallelism and remove padding features"""
use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1
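
Because the LR scheduler is built on a single anchor optimizer while the fused path owns one optimizer per parameter, `update_scheduler_with_custom_step` presumably has to fan the scheduled learning rate out to all of them. A sketch under that assumption; the real helper in `verl.utils.torch_functional` is not shown in this diff:

```python
import functools


def update_scheduler_with_custom_step_sketch(lr_scheduler, optim_dict) -> None:
    """Wrap lr_scheduler.step so every fused optimizer follows the same schedule."""
    original_step = lr_scheduler.step

    @functools.wraps(original_step)
    def step(*args, **kwargs):
        original_step(*args, **kwargs)
        new_lr = lr_scheduler.get_last_lr()[0]
        for opt in optim_dict.values():
            for group in opt.param_groups:
                group["lr"] = new_lr  # keep per-parameter optimizers in sync

    lr_scheduler.step = step
```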
@@ -367,27 +402,33 @@ def training_step(self, batch: TensorDict):

log_gpu_memory_usage("Before optimizer zero_grad", logger=logger)

self.optimizer.zero_grad()

if not self.optim_bwd_hook:
self.optimizer.zero_grad()
log_gpu_memory_usage("After optimizer zero_grad", logger=logger)

micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)
n_micro_batches = len(micro_batches)

step_loss = 0
for micro_batch in micro_batches:
loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches

if self.micro_batch_size > 0:
micro_batches = batch.split(self.micro_batch_size)
n_micro_batches = len(micro_batches)
for micro_batch in micro_batches:
loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches
step_loss += loss.item()
else:
loss = self._compute_loss_and_backward(batch)
step_loss += loss.item()

grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)

log_gpu_memory_usage("Before optimizer step", logger=logger)

# if grad_norm is not finite, skip the update
if not torch.isfinite(grad_norm):
print(f"WARN: grad_norm is not finite: {grad_norm}")
self.optimizer.zero_grad()
else:
self.optimizer.step()
if not self.optim_bwd_hook:
if not torch.isfinite(grad_norm):
print(f"WARN: grad_norm is not finite: {grad_norm}")
self.optimizer.zero_grad()
else:
self.optimizer.step()

log_gpu_memory_usage("After optimizer step", logger=logger)

@@ -471,7 +512,7 @@ def fit(self):
# Perform final validation
val_losses = []
for val_data in self.val_dataloader:
val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda()
val_data = TensorDict(val_data, batch_size=self.config.data.val_batch_size).cuda()
val_loss = self.validation_step(val_data)
val_losses.append(val_loss)
if rank == 0:
@@ -487,7 +528,7 @@ def fit(self):
# validation
val_losses = []
for data in self.val_dataloader:
data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda()
data = TensorDict(data, batch_size=self.config.data.val_batch_size).cuda()
val_loss = self.validation_step(data)
val_losses.append(val_loss)
if rank == 0:
@@ -500,7 +541,16 @@ def fit(self):
self.save_checkpoint(step=global_step)


from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
import hydra

from torch.distributed.device_mesh import init_device_mesh

from verl.utils.distributed import initialize_global_process_group


@hydra.main(config_path="config", config_name="sft_trainer", version_base=None)

def main(config):
local_rank, rank, world_size = initialize_global_process_group()

21 changes: 12 additions & 9 deletions verl/trainer/ppo/ray_trainer.py
@@ -332,7 +332,8 @@ def _validate_config(self):

# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % n_gpus == 0, f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."
assert real_train_batch_size % n_gpus == 0, \
f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."

# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
@@ -353,15 +354,16 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")

if mbs is not None and mbs_per_gpu is not None:
raise ValueError(f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove '{name}.{param}' because only '*_{param_per_gpu}'" + "is supported (the former is deprecated).")
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. "
f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
)

if not config.actor_rollout_ref.actor.use_dynamic_bsz:
# actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.actor.ppo_micro_batch_size,
config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
"actor_rollout_ref.actor",
)
check_mutually_exclusive(config.actor_rollout_ref.actor.ppo_micro_batch_size,
config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
"actor_rollout_ref.actor")

if self.use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
@@ -380,7 +382,8 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):

if self.use_critic and not config.critic.use_dynamic_bsz:
# Check for critic micro-batch size conflicts
check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic")
check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu,
"critic")

# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
@@ -393,7 +396,7 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
# ppo_micro_batch_size * sequence_parallel_size >= n_gpus
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
sp_size = config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1)
if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0
assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus
1 change: 1 addition & 0 deletions verl/utils/debug/performance.py
@@ -47,6 +47,7 @@ def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging
if logger is None:
print(message)
else:
print(message)
logger.log(msg=message, level=level)

