Changes from 1 commit
Commits (28)
fcdc0f5  initial test (tinuademargaret, Apr 8, 2025)
d463293  initial profiling (Apr 8, 2025)
4b91919  initial fusion test (tinuademargaret, Apr 8, 2025)
2d8e0bc  fixes (tinuademargaret, Apr 8, 2025)
3ff92ac  fix lr scheduler (tinuademargaret, Apr 10, 2025)
f991417  debugging fusion with gradient accumulation (tinuademargaret, Apr 11, 2025)
50c8db3  use flag (tinuademargaret, Apr 17, 2025)
7d263e2  fixes (tinuademargaret, Apr 21, 2025)
b632651  tests (tinuademargaret, Apr 23, 2025)
81930ea  test flat params (tinuademargaret, Apr 23, 2025)
74e09c7  update params to flat params (tinuademargaret, Apr 23, 2025)
281fc87  sft optimiser fuse (tinuademargaret, Apr 23, 2025)
bf2a0a7  fix batch size (tinuademargaret, Apr 23, 2025)
3a5d7e9  fix normalise bsz (tinuademargaret, Apr 23, 2025)
c6a39b7  test ppo config (tinuademargaret, Apr 23, 2025)
ac0d426  fix config (tinuademargaret, Apr 24, 2025)
20af206  add bwd hook to actor worker (tinuademargaret, Apr 24, 2025)
55c35ef  update sft trainer (tinuademargaret, Apr 24, 2025)
e7b7acc  fixes (tinuademargaret, Apr 24, 2025)
d827d11  update critic (tinuademargaret, Apr 25, 2025)
22565e1  fix (tinuademargaret, Apr 25, 2025)
f2f94ef  delete prev memory (tinuademargaret, Apr 28, 2025)
ccefe19  remove prev scheduler (tinuademargaret, Apr 28, 2025)
70a1859  Merge branch 'main' into feat-optimiser-fuse (tinuademargaret, Apr 28, 2025)
2c43fa5  revert changes for rl workers (tinuademargaret, May 7, 2025)
ccd3658  update sft config (tinuademargaret, May 7, 2025)
a098aa0  clean up (tinuademargaret, May 8, 2025)
7df87b6  Merge branch 'main' into feat-optimiser-fuse (tinuademargaret, May 8, 2025)
debugging fusion with gradient accumulation
tinuademargaret committed Apr 11, 2025
commit f9914178c1c89a4feff6c37a2ecb5712822c120c
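
This commit threads gradient accumulation through the optimizer-in-backward path: each parameter gets its own optimizer, a post-accumulate-grad hook counts backward passes, and the optimizer step only fires once `acc_steps` microbatches have contributed to that parameter's gradient. A minimal standalone sketch of the pattern (not verl code; `acc_steps` and `optim_dict` mirror the names used in this PR, while the toy model and hyperparameters are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 4)
acc_steps = 4  # microbatch backward passes per optimizer update

# Optimizer-in-backward: one optimizer per parameter.
optim_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}
step_count = {p: 0 for p in model.parameters()}

def optim_step_hook(p):
    # Fires after each backward pass has accumulated into p.grad.
    step_count[p] += 1
    if step_count[p] % acc_steps == 0:
        optim_dict[p].step()
        optim_dict[p].zero_grad()

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optim_step_hook)

# Eight microbatches -> the hooks step each parameter twice.
for _ in range(8):
    model(torch.randn(2, 8)).mean().backward()
```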
55 changes: 55 additions & 0 deletions verl/trainer/config/sft_trainer_test.yaml
@@ -0,0 +1,55 @@
data:
  train_batch_size: 256
  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
  micro_batch_size_per_gpu: 2 # this is also val batch size
  train_files: ~/data/gsm8k/train.parquet
  val_files: ~/data/gsm8k/test.parquet
  # Single-turn settings
  prompt_key: extra_info
  response_key: extra_info
  prompt_dict_keys: ['question']
  response_dict_keys: ['answer']
  # Multi-turn settings
  multiturn:
    enable: false # Set to true to use multi-turn dataset
    messages_key: messages # Key for messages list in multi-turn mode
  max_length: 1024
  truncation: error
  balance_dp_token: False
  chat_template: null
  custom_cls:
    path: null
    name: null
model:
  partial_pretrain: Qwen/Qwen2.5-0.5B-Instruct
  fsdp_config:
    wrap_policy:
      min_num_params: 0
    cpu_offload: False
    offload_params: False
  external_lib: null
  enable_gradient_checkpointing: False
  trust_remote_code: False
  lora_rank: 32 # Set to positive value to enable LoRA (e.g., 32)
  lora_alpha: 16 # LoRA scaling factor
  target_modules: all-linear # Target modules for LoRA adaptation
  use_liger: False
optim:
  lr: 1e-4
  betas: [0.9, 0.95]
  weight_decay: 0.01
  warmup_steps_ratio: 0.1
  clip_grad: 1.0
ulysses_sequence_parallel_size: 1
use_remove_padding: False
trainer:
  default_local_dir: ~/data
  default_hdfs_dir: null # change the hdfs path here
  resume_path: null
  project_name: gsm8k-sft
  experiment_name: test
  total_epochs: 1
  total_training_steps: null
  logger: ['console']
  seed: 1
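
For reference, the batch settings above are what the trainer change below turns into an accumulation count. A quick sanity check, mirroring the `acc_steps = batch_size//microbatchsize` line added to `_build_model_optimizer` and ignoring any per-rank batch-size normalization the trainer may apply elsewhere:

```python
train_batch_size = 256        # data.train_batch_size
micro_batch_size_per_gpu = 2  # data.micro_batch_size_per_gpu
acc_steps = train_batch_size // micro_batch_size_per_gpu
print(acc_steps)  # 128 backward passes per optimizer step
```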

13 changes: 9 additions & 4 deletions verl/trainer/fsdp_sft_trainer.py
@@ -273,12 +273,16 @@ def _build_model_optimizer(self):
 
         log_gpu_memory_usage('After FSDP wrapping', logger=logger)
 
+        batch_size = self.config.data.train_batch_size
+        microbatchsize = self.config.data.micro_batch_size_per_gpu
+        acc_steps = batch_size//microbatchsize
+
         self.optimizer = None
         if self.optim_bwd_hook:
             self.optim_dict = {
                 param: optim.AdamW([param], lr=self.config.optim.lr, betas=self.config.optim.betas, weight_decay=self.config.optim.weight_decay) for param in self.fsdp_model.parameters()
             }
-            register_optim_in_bwd_hooks(model=self.fsdp_model, optim_dict=self.optim_dict)
+            register_optim_in_bwd_hooks(model=self.fsdp_model, optim_dict=self.optim_dict, acc_steps=acc_steps)
         else:
             self.optimizer = optim.AdamW(self.fsdp_model.parameters(),
                                          lr=self.config.optim.lr,
@@ -430,6 +434,9 @@ def training_step(self, batch: TensorDict):
 
         if not self.optim_bwd_hook:
             self.optimizer.zero_grad()
+        else:
+            for opt in self.optim_dict.values():
+                opt.zero_grad()
 
         # log_gpu_memory_usage('After optimizer zero_grad', logger=logger)
 
@@ -459,8 +466,6 @@ def training_step(self, batch: TensorDict):
         # reduce loss across dp ranks
         lr = self.lr_scheduler.get_last_lr()[0]
 
-        print("learning rate: {lr}")
-
         log_gpu_memory_usage('After offload weights', logger=logger)
 
         step_loss = torch.tensor(step_loss).cuda()
@@ -567,7 +572,7 @@ def fit(self):
 from verl.utils.distributed import initialize_global_process_group
 
 
-@hydra.main(config_path='config', config_name='sft_trainer', version_base=None)
+@hydra.main(config_path='config', config_name='sft_trainer_test', version_base=None)
 def main(config):
     local_rank, rank, world_size = initialize_global_process_group()
 
38 changes: 22 additions & 16 deletions verl/utils/memory.py
@@ -18,26 +18,32 @@
 import torch
 
 def register_optim_in_bwd_hooks(
-    model: torch.nn.Module, optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer]
+    model: torch.nn.Module,
+    optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer],
+    acc_steps: int, # number of microbatches to accumulate,
 ) -> None:
     """
-    Register hooks for optimizer step running in backward.
-
-    When fusing the optimizer step into backward, we need to call ``.step()`` on the optimizer
-    for a given parameter as soon as its gradient is ready. This utility registers post-accumulate-grad
-    hooks on all parameters in the model to achieve this.
-
-    Args:
-        model (torch.nn.Module): Model whose parameters will be optimized. Note that currently
-            hooks for ALL parameters in the model will be registered.
-        optim_dict (Dict[torch.nn.Parameter, torch.optim.Optimizer]): Mapping from
-            parameters to optimizers.
+    Register backward hooks that only perform an optimizer step after `acc_steps`
+    backward calls on each parameter.
     """
 
     def optim_step(param) -> None:
-        optim_dict[param].step()
-        optim_dict[param].zero_grad()
+        # Get or initialize an accumulation counter on the parameter.
+        if not hasattr(param, '_accumulation_counter'):
+            param._accumulation_counter = 0
+        param._accumulation_counter += 1
+
+        # Only update when we've accumulated gradients from all microbatches.
+        if param._accumulation_counter % acc_steps == 0:
+            # print("Autocast enabled before optimizer step:", torch.is_autocast_enabled())
+            # with torch.amp.autocast(device_type='cuda', enabled=False):
+            # print("Autocast Enabled before optimizer step:", torch.is_autocast_enabled())
+            param.data = param.data.float()
+            print(f"Param data type: {param.data.dtype}")
+            optim_dict[param].step()
+            # optim_dict[param].zero_grad()
+            # Resetting or implicitly allowing counter to roll-over
+            # (optional: you could set param._accumulation_counter = 0)
 
     for p in model.parameters():
-        if p.requires_grad:
-            p.register_post_accumulate_grad_hook(optim_step)
+        p.register_post_accumulate_grad_hook(optim_step)
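
A tiny self-contained check of the accumulate-then-step behaviour the hook above implements (a standalone sketch, not verl code; zeroing is folded into the hook here for brevity, whereas this commit leaves `zero_grad` to `training_step`):

```python
import torch

w = torch.nn.Parameter(torch.ones(1))
opt = torch.optim.SGD([w], lr=0.1)
acc_steps, calls = 2, 0

def hook(p):
    global calls
    calls += 1
    if calls % acc_steps == 0:   # step on every 2nd backward pass
        opt.step()
        opt.zero_grad()

w.register_post_accumulate_grad_hook(hook)

for _ in range(4):
    (w * 3.0).sum().backward()   # adds 3.0 to w.grad on each call

print(w.item())  # -0.2: two steps, each lr * accumulated grad = 0.1 * 6.0
```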