Update readme
bwasti committed Nov 4, 2025
commit 53b56a5ea5cb3689c73c99df8c75528a281de912
116 changes: 55 additions & 61 deletions torchtitan/experiments/deterministic_vllm_rl/README.md
Can we mention that this only works for single device right now, and we plan to extend it to work with parallelisms?

@@ -1,51 +1,50 @@
# Deterministic RL Training with vLLM

This experiment provides a complete framework for **bitwise-deterministic reinforcement learning training** that combines:
- **vLLM** for fast, deterministic rollouts
- **TorchTitan** for efficient training with gradients
- **Custom backward passes** to maintain determinism through the entire training loop
This experiment combines vLLM's deterministic kernels with PyTorch autograd to enable reinforcement learning training where forward passes produce bitwise-identical results across runs.

## Overview

Traditional RL training faces a challenge: you want fast inference for generating rollouts, but you also need gradients for training. vLLM is extremely fast but doesn't support gradients. Standard PyTorch supports gradients but can be non-deterministic.
RL training requires both fast inference for generating rollouts and gradient computation for policy updates. vLLM provides deterministic forward passes but does not support gradients. This experiment adds backward passes to vLLM's operations.

This experiment solves both problems by:
1. Using vLLM's deterministic kernels for forward passes (both rollouts and training)
2. Adding custom backward passes that are also deterministic
3. Achieving **bitwise-identical results** across runs for the entire training loop
The implementation:
1. Uses vLLM's batch-invariant kernels for forward passes
2. Implements custom backward passes for gradient computation
3. Provides weight conversion utilities between TorchTitan and vLLM formats

### Key Features
### Features

- **Bitwise Determinism**: Same inputs always produce identical outputs (bit-for-bit)
- **vLLM Speed**: Fast rollouts using vLLM's optimized kernels
- **Gradient Support**: Full backward pass support for training
- **Model Compatibility**: Drop-in replacement for standard Qwen3 models in TorchTitan
- Bitwise determinism: Same inputs produce identical outputs across runs
- Gradient support: Backward passes through vLLM operations
- Weight conversion: Utilities to convert between model formats

**Note**: This experiment currently supports single-device training only. We plan to extend support for distributed training with tensor parallelism and pipeline parallelism in the future.
Note: Currently supports single-device training only.

## Architecture

### Components

1. **`models/attention.py`**: VLLMCompatibleFlashAttention
1. `models/attention.py`: VLLMCompatibleFlashAttention
- Uses vLLM's Flash Attention for forward pass
- Implements custom backward pass for gradients
- Maintains determinism with `num_splits=1`
- Implements custom backward pass for gradient computation
- Uses `num_splits=1` for deterministic behavior

2. **`models/qwen3/model_vllm_compat.py`**: Qwen3VLLMCompatModel
- vLLM-compatible Qwen3 implementation
- Merged gate/up projections (like vLLM)
2. `models/qwen3/model_vllm_compat.py`: Qwen3VLLMCompatModel
- Qwen3 model with merged gate/up projections matching vLLM format
- Uses VLLMRMSNorm with gradient support

3. **`batch_invariant_backward.py`**: Gradient support for vLLM operations
- Registers backward passes for vLLM's batch-invariant operations
3. `batch_invariant_backward.py`: Backward passes for vLLM operations
- Registers gradients for vLLM's batch-invariant operations
- Supports matmul, linear, and RMSNorm
- Patches Flash Attention to work with autograd
- Patches Flash Attention for autograd

4. **`simple_rl.py`**: End-to-end RL training loop
- Generates rollouts using vLLM
4. `weights_vllm_compat.py`: Weight conversion utilities
- Converts between TorchTitan format (separate w1, w2, w3) and vLLM format (merged gate_up_proj)
- Provides bidirectional conversion functions

5. `simple_rl.py`: RL training loop
- Generates rollouts using vLLM engine
- Computes advantages using GRPO-style ranking
- Updates policy using PPO with bitwise-deterministic gradients
- Updates policy using PPO

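A minimal sketch of the gate/up merge performed by the conversion utilities described above, assuming vLLM's convention of concatenating the gate and up projection weights along the output dimension. The helper names here are illustrative, not the actual API of `weights_vllm_compat.py`:

```python
import torch

def merge_gate_up(w1: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
    # TorchTitan keeps the gate (w1) and up (w3) projections separate;
    # vLLM expects a single gate_up_proj weight with gate stacked above up.
    return torch.cat([w1, w3], dim=0)

def split_gate_up(gate_up: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Inverse operation: recover the separate w1 (gate) and w3 (up) weights.
    w1, w3 = gate_up.chunk(2, dim=0)
    return w1, w3
```

The actual utilities operate over full checkpoints rather than single tensors; see `weights_vllm_compat.py` for the supported entry points.
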
## Installation

@@ -64,7 +63,7 @@ pip install transformers safetensors huggingface_hub tensorboard

### Enable Batch Invariance

Before running any training, you must initialize vLLM's batch-invariant mode:
Initialize vLLM's batch-invariant mode before training:

```python
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
@@ -102,14 +101,14 @@ model = Qwen3VLLMCompatModel(model_args)
input_ids = torch.randint(0, 151936, (2, 128), device='cuda')
logits = model(input_ids)

# 4. Backward pass (also deterministic!)
# 4. Backward pass
loss = logits.sum()
loss.backward()
```
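As a quick sanity check building on the snippet above (an illustrative addition, not code from the repo), the same inputs can be run through the model twice and compared bit-for-bit:

```python
# Two identical forward passes should match exactly when
# batch-invariant mode is enabled.
with torch.no_grad():
    logits_a = model(input_ids)
    logits_b = model(input_ids)

assert torch.equal(logits_a, logits_b), "forward pass is not bitwise deterministic"
```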

### Full RL Training

Run the complete RL training loop:
Run the RL training loop:

```bash
VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl
@@ -132,7 +131,7 @@ tensorboard --logdir=./outputs/rl_training

### Deterministic Forward Pass

vLLM's batch-invariant mode ensures that all operations are deterministic:
vLLM's batch-invariant mode makes operations deterministic:

```python
# These operations are deterministic when batch_invariance is enabled
@@ -142,16 +141,16 @@ output = flash_attn_varlen_func(q, k, v, num_splits=1)  # Deterministic FA

### Backward Pass with Gradients

We add custom backward passes that:
1. Re-compute attention weights (deterministic)
Custom backward passes:
1. Re-compute attention weights deterministically
2. Use standard chain rule for gradients
3. Apply gradients through vLLM's deterministic operations

```python
class FlashAttnWithBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, ...):
        # Use vLLM's fast forward
        # Use vLLM's forward implementation
        return flash_attn_varlen_func(q, k, v, num_splits=1, ...)

    @staticmethod
@@ -163,79 +162,74 @@ class FlashAttnWithBackward(torch.autograd.Function):

### Bitwise Determinism Verification

The training loop verifies that vLLM and TorchTitan produce identical logprobs:
The training loop compares logprobs from vLLM and TorchTitan:

```python
# During training, compare logprobs
vllm_logprobs = ...   # logprobs captured from the vLLM rollout
titan_logprobs = ...  # logprobs from the TorchTitan forward pass

assert torch.equal(vllm_logprobs, titan_logprobs) # Should be true!
assert torch.equal(vllm_logprobs, titan_logprobs)
```

## Testing

Run the test suite to verify determinism:
Run the test suite:

```bash
cd torchtitan/experiments/deterministic_vllm_rl/tests

# Test backward passes work correctly
# Test backward passes
python test_batch_invariant_backward.py

# Test exact determinism (bit-for-bit identical)
# Test determinism
python test_exact_determinism.py
```

Expected output:
```
✓ All operations are exactly deterministic!
✓ vLLM-TorchTitan bitwise determinism verified: N tokens match exactly
```

## Technical Details

### Why Determinism Matters for RL

In RL training, we need to:
1. Generate rollouts (sampling from the policy)
RL training steps:
1. Generate rollouts by sampling from the policy
2. Compute rewards based on the samples
3. Update the policy using gradients

**The problem**: If the forward pass during training differs from the forward pass during rollout, the gradients will be wrong! This is especially important for PPO, which compares old and new policy probabilities.
If the forward pass during training differs from the forward pass during rollout, policy gradients may be incorrect. This matters for algorithms like PPO that compare old and new policy probabilities.

**The solution**: Use the same deterministic kernels for both rollouts (vLLM) and training (TorchTitan). This ensures that `logprobs_rollout == logprobs_training` (bitwise).
This implementation uses the same kernels for both rollouts (vLLM) and training (TorchTitan) to ensure `logprobs_rollout == logprobs_training` bitwise.
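As an illustration of why this matters (a sketch, not code from `simple_rl.py`): the PPO importance ratio is computed directly from these logprobs, so any mismatch between rollout and training forward passes shows up as a spurious ratio.

```python
import torch

def ppo_ratio(new_logprobs: torch.Tensor, old_logprobs: torch.Tensor) -> torch.Tensor:
    # Importance ratio exp(log pi_new - log pi_old). With bitwise-identical
    # forward passes, the ratio is exactly 1.0 for every token before the
    # first optimizer step; numerical drift would make it != 1 and bias the
    # clipped PPO objective.
    return torch.exp(new_logprobs - old_logprobs)
```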

### Performance

- **Rollout speed**: ~100x faster than standard PyTorch (thanks to vLLM)
- **Training speed**: Same as standard TorchTitan
- **Memory**: Slightly higher (saves activations for custom backward)
- Rollout speed: Uses vLLM's optimized kernels
- Training speed: Similar to standard TorchTitan
- Memory: Saves activations for custom backward passes

### Limitations

1. **Sequence length**: Custom backward requires uniform sequence lengths
2. **Attention**: Only causal attention is supported
3. **Hardware**: Requires NVIDIA GPUs with Flash Attention support
1. Custom backward requires uniform sequence lengths
2. Only causal attention is supported
3. Requires NVIDIA GPUs with Flash Attention support

## Project Structure

```
deterministic_vllm_rl/
├── README.md                        # This file
├── README.md                        # Documentation
├── __init__.py                      # Package initialization
├── batch_invariant_backward.py      # Gradient support for vLLM ops
├── simple_rl.py                     # End-to-end RL training loop
├── batch_invariant_backward.py      # Backward passes for vLLM ops
├── weights_vllm_compat.py           # Weight conversion utilities
├── simple_rl.py                     # RL training loop
├── models/
│   ├── __init__.py
│   ├── attention.py                 # VLLMCompatibleFlashAttention
│   └── qwen3/
│       ├── __init__.py
│       └── model_vllm_compat.py     # vLLM-compatible Qwen3 model
│       └── model_vllm_compat.py     # vLLM-compatible Qwen3 model
└── tests/
    ├── __init__.py
    ├── test_batch_invariant_backward.py  # Test gradients work
    └── test_exact_determinism.py         # Test bitwise determinism
    ├── test_batch_invariant_backward.py  # Test backward passes
    └── test_exact_determinism.py         # Test determinism
```

## Contributing
2 changes: 1 addition & 1 deletion torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
similar, and also the two test files

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

@@ -732,7 +732,7 @@ def main():
    output_dir = "./converted"

    group_size = 4  # For GRPO - samples per prompt
    num_steps = 2  # Quick test - change to 100 for full training
    num_steps = 100
    learning_rate = 1e-5

    # Check if batch invariance is enabled