Remove gradient debugging statements
Remove temporary debugging code added for gradient checking:
- Removed print_gradient_report() function
- Removed gradient tracking code in rl_update_step()
- Removed gradient-related console output
- Removed gradient metrics from TensorBoard logging
- Removed step parameter from rl_update_step() (no longer needed)

The gradient issue has been fixed, so this debugging code is no longer needed.
bwasti committed Nov 4, 2025
commit d53740a340d3d6a43e3fc318bb454b4ee3fa2c53
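For reference, the coverage check this commit strips out is easy to reproduce ad hoc if gradient debugging is ever needed again. Below is a minimal standalone sketch of the same pattern, assuming a standard PyTorch module and an already-backpropagated loss; the gradient_coverage helper name is illustrative, not part of the codebase.

import torch

def gradient_coverage(model: torch.nn.Module) -> dict:
    """Tally which trainable parameters received a gradient after backward()."""
    params_with_grad = 0
    params_without_grad_names = []
    sq_norm_sum = 0.0
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is not None:
            params_with_grad += 1
            sq_norm_sum += param.grad.norm().item() ** 2
        else:
            params_without_grad_names.append(name)
    return {
        "params_with_grad": params_with_grad,
        "params_without_grad": len(params_without_grad_names),
        "params_without_grad_names": params_without_grad_names,
        "total_grad_norm": sq_norm_sum ** 0.5,
    }

# Usage sketch: call between loss.backward() and optimizer.step()/zero_grad().
# report = gradient_coverage(model)
# if report["params_without_grad"]:
#     print("Missing grads:", report["params_without_grad_names"][:5])
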
101 changes: 0 additions & 101 deletions torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
Reviewer comment (Contributor): similar, and also the two test files

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

@@ -592,7 +592,6 @@ def rl_update_step(
max_new_tokens: int = 20,
temperature: float = 1.0,
use_vllm_compat: bool = True,
step: int = 0,
) -> dict:
"""
Perform one RL update step using vLLM for rollouts.
@@ -608,7 +607,6 @@ def rl_update_step(
max_new_tokens: Max tokens to generate
temperature: Sampling temperature
use_vllm_compat: Whether to use vLLM-compatible model
step: Current training step (for gradient reporting)

Returns:
metrics: Dict of training metrics
@@ -651,27 +649,6 @@
)
loss.backward()

# Check which parameters have gradients
params_with_grad = 0
params_without_grad = 0
params_without_grad_names = []
total_grad_norm = 0.0

for name, param in model.named_parameters():
if param.requires_grad:
if param.grad is not None:
params_with_grad += 1
total_grad_norm += param.grad.norm().item() ** 2
else:
params_without_grad += 1
params_without_grad_names.append(name)

total_grad_norm = total_grad_norm ** 0.5

# Print detailed gradient report on first step (before optimizer clears gradients)
if step == 0:
print_gradient_report(model, step)

# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

@@ -685,81 +662,12 @@
"advantage_mean": advantages.mean().item(),
"advantage_std": advantages.std().item(),
"sample_completions": completions[:2], # First 2 for inspection
"params_with_grad": params_with_grad,
"params_without_grad": params_without_grad,
"params_without_grad_names": params_without_grad_names,
"total_grad_norm": total_grad_norm,
**loss_metrics, # Include all loss metrics (pg_loss, kl_div, entropy, ratio stats, logprob comparisons)
}

return metrics


def print_gradient_report(model: torch.nn.Module, step: int = 0):
"""
Print a detailed report of which parameters have gradients.

Args:
model: The model to inspect
step: Current training step (for logging)
"""
print("\n" + "=" * 80)
print(f"GRADIENT REPORT (Step {step})")
print("=" * 80)

params_by_module = {}
for name, param in model.named_parameters():
if not param.requires_grad:
continue

# Get module name (everything before the last dot)
module_name = '.'.join(name.split('.')[:-1]) if '.' in name else 'root'

if module_name not in params_by_module:
params_by_module[module_name] = {'with_grad': [], 'without_grad': []}

if param.grad is not None:
grad_norm = param.grad.norm().item()
params_by_module[module_name]['with_grad'].append((name, grad_norm))
else:
params_by_module[module_name]['without_grad'].append(name)

# Print summary by module
total_with = 0
total_without = 0

for module_name in sorted(params_by_module.keys()):
info = params_by_module[module_name]
n_with = len(info['with_grad'])
n_without = len(info['without_grad'])
total_with += n_with
total_without += n_without

if n_without > 0:
status = "⚠ MISSING GRADS"
else:
status = "✓"

print(f"\n{status} {module_name}:")
print(f" Params with grad: {n_with}, without grad: {n_without}")

if n_without > 0:
print(f" Missing gradients for:")
for pname in info['without_grad']:
print(f" - {pname}")

# Show top 5 params by gradient norm
if n_with > 0:
top_grads = sorted(info['with_grad'], key=lambda x: x[1], reverse=True)[:5]
print(f" Top gradient norms:")
for pname, gnorm in top_grads:
print(f" - {pname}: {gnorm:.6e}")

print("\n" + "=" * 80)
print(f"TOTAL: {total_with} params with grad, {total_without} params without grad")
print("=" * 80 + "\n")


def compute_weight_deltas(model: torch.nn.Module, initial_state: dict) -> dict:
"""
Compute weight changes from initial state based on magnitude (L2 norm).
@@ -903,7 +811,6 @@ def main():
max_new_tokens=20,
temperature=1.0,
use_vllm_compat=use_vllm_compat,
step=step,
)

# Compute weight deltas from initial state
@@ -920,9 +827,6 @@
writer.add_scalar('rl/reward_std', metrics['reward_std'], step)
writer.add_scalar('rl/advantage_mean', metrics['advantage_mean'], step)
writer.add_scalar('rl/advantage_std', metrics['advantage_std'], step)
writer.add_scalar('rl/total_grad_norm', metrics['total_grad_norm'], step)
writer.add_scalar('rl/params_with_grad', metrics['params_with_grad'], step)
writer.add_scalar('rl/params_without_grad', metrics['params_without_grad'], step)

# Log weight deltas
for key, value in weight_deltas.items():
@@ -931,11 +835,6 @@
print(f"\nStep {step:3d} | Loss: {metrics['loss']:.4f} | "
f"Reward: {metrics['reward_mean']:+.3f}±{metrics['reward_std']:.3f} | "
f"Advantage: {metrics['advantage_mean']:+.3f}±{metrics['advantage_std']:.3f}")
print(f" Grad: {metrics['params_with_grad']} params with grad, "
f"{metrics['params_without_grad']} without | "
f"Grad norm: {metrics['total_grad_norm']:.4f}")
if metrics['params_without_grad'] > 0:
print(f" ⚠ Params without gradients: {metrics['params_without_grad_names'][:5]}")
print(f" Sample: {metrics['sample_completions'][0][:60]}...")

print("\n" + "=" * 80)