Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 10 additions & 25 deletions examples/grpo_trainer/run_gptoss_20b.sh
Original file line number Diff line number Diff line change
@@ -1,25 +1,5 @@
#!/bin/bash

# Environment setup: build flashinfer from source, switch an sglang checkout
# to the branch of upstream PR 9379, then install the Python packages this
# example depends on.

# install flashinfer
# Abort if a directory change fails so the installs below never run from an
# unexpected working directory.
cd "$HOME" || exit 1
git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
cd flashinfer || exit 1
python -m pip install -v .

# install sglang
# NOTE(review): assumes an sglang clone already exists at $HOME/sglang;
# nothing in this section creates it. Confirm against the setup docs.
cd "$HOME/sglang" || exit 1
# The fetch must run inside the sglang repository: the PR ref is created as a
# local branch there, and the checkout below needs that branch to exist.
git fetch origin pull/9379/head:fix_weight_loading
git checkout fix_weight_loading
pip install --upgrade pip
pip install -e "python[all]"

pip install peft
pip install transformers -U
# Prebuilt flash-attn wheel; the filename targets CUDA 12, torch 2.8 and
# CPython 3.11 on linux x86_64.
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp311-cp311-linux_x86_64.whl
pip install numpy==1.26.4


cat > get_model.py << EOF
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
Expand Down Expand Up @@ -47,19 +27,23 @@ tokenizer.save_pretrained(output_dir)
EOF

python get_model.py

# Alternatively, you can use lmsys/gpt-oss-20b-bf16.
# It is recommended to use the same value for train_batch_size and
# ppo_mini_batch_size to avoid MoE training instability.
# Use a large value for max_response_length if you want to use high reasoning effort.


model_dir=$HOME/models/gpt-oss-20b-bf16
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files="$gsm8k_train_path" \
data.val_files="$gsm8k_test_path" \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.train_batch_size=256 \
data.max_prompt_length=512 \
data.max_response_length=8192 \
data.filter_overlong_prompts=True \
data.truncation='error' \
+data.apply_chat_template_kwargs.reasoning_effort=medium \
actor_rollout_ref.model.path=${model_dir} \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
Expand All @@ -76,8 +60,9 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.mode=sync \
actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=triton \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
Expand Down