Skip to content

Commit 13ad653

Browse files
shjwudpyaox12 and Xin Yao authored
[Dev] Fix Linear-Cross-Entropy Convergence Issue (NVIDIA#2739)
Co-authored-by: Xin Yao <xiny@nvidia.com>
1 parent bfa1d31 commit 13ad653

File tree

5 files changed

+169
-104
lines changed

5 files changed

+169
-104
lines changed

megatron/core/fusions/linear_cross_entropy/blackwell/entry.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,8 @@ def backward(
345345
and num_valid_tokens.dtype == torch.int64
346346
)
347347

348-
d_hidden = torch.empty_like(global_hidden)
348+
# Allocate d_hidden in float32 for better numerical stability
349+
d_hidden = torch.empty_like(global_hidden, dtype=torch.float32)
349350
d_weight = torch.empty_like(weight)
350351
assert d_hidden.is_contiguous() and d_weight.is_contiguous()
351352

@@ -435,14 +436,15 @@ def backward(
435436
)
436437
valid_d_logits = _d_logits[:, :vocab_right_bound]
437438

438-
torch.addmm(
439-
input=d_hidden.view(-1, dim),
440-
mat1=valid_d_logits,
441-
mat2=weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :],
442-
beta=(split_idx != 0),
443-
alpha=1.0,
444-
out=d_hidden.view(-1, dim),
445-
)
439+
_delta_hidden = torch.mm(
440+
valid_d_logits,
441+
weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :],
442+
out_dtype=torch.float32,
443+
).view_as(d_hidden)
444+
if split_idx == 0:
445+
d_hidden.copy_(_delta_hidden)
446+
else:
447+
d_hidden.add_(_delta_hidden)
446448
torch.matmul(
447449
valid_d_logits.T,
448450
hidden_view,
@@ -466,6 +468,9 @@ def backward(
466468
]
467469
d_hidden = d_hidden.view(partial_hidden_shape).clone()
468470

471+
# convert d_hidden to the original dtype
472+
d_hidden = d_hidden.type_as(global_hidden)
473+
469474
return d_hidden, d_weight
470475

471476
except ImportError:

megatron/core/models/common/language_module/language_module.py

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
import logging
33
import os
4-
from typing import Any, Dict, Literal, Optional, Tuple
4+
from typing import Optional, Tuple
55

66
import torch
77
from torch import Tensor
@@ -14,7 +14,6 @@
1414
except:
1515
te_parallel_cross_entropy = None
1616
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
17-
from megatron.core.fusions.fused_linear_cross_entropy import linear_cross_entropy
1817
from megatron.core.pipeline_parallel.utils import (
1918
is_pp_first_stage,
2019
is_pp_last_stage,
@@ -126,68 +125,6 @@ def check_and_set_env_variable(
126125
check_and_set_env_variable("NVTE_FUSED_ATTN", 1, AttnBackend.auto)
127126
check_and_set_env_variable("NVTE_UNFUSED_ATTN", 1, AttnBackend.auto)
128127

129-
def compute_output_layer_and_language_model_loss(
130-
self,
131-
hidden: Tensor,
132-
labels: Optional[Tensor],
133-
weight: Tensor = None,
134-
sequence_parallel_enabled: bool = False,
135-
column_parallel_linear: torch.nn.Module = None,
136-
col_linear_kwargs: Dict[str, Any] = {},
137-
reduction: Literal["none", "sum", "mean"] = "none",
138-
ignore_index: int = -100,
139-
) -> Tensor:
140-
"""Computes the language model logits and loss (Cross entropy across vocabulary)
141-
142-
Args:
143-
hidden (Tensor): The hidden states from the transformer model
144-
labels (Optional[Tensor]): The labels of dimension [batch size, seq length]
145-
weight (Tensor): The weight tensor of shape [vocab size, hidden size].
146-
Required if using fused linear cross entropy.
147-
column_parallel_linear (torch.nn.Module): The column parallel linear
148-
layer to use for computing logits when not using fused linear cross entropy.
149-
col_linear_kwargs (Dict[str, Any]): Additional kwargs for column parallel linear layer
150-
reduction (Optional[str]): The reduction method. Defaults to "none", and can be
151-
one of "none", "sum", "mean".
152-
ignore_index (Optional[int]): The index to ignore in the loss calculation.
153-
Defaults to -100.
154-
155-
Returns:
156-
Tensor: Loss tensor of dimensions [batch size, sequence_length].
157-
"""
158-
if (
159-
self.config.cross_entropy_loss_fusion
160-
and self.config.cross_entropy_fusion_impl == 'linear'
161-
):
162-
assert (
163-
weight is not None
164-
), "weight cannot be None when using fused linear cross entropy."
165-
assert (
166-
labels is not None
167-
), "labels cannot be None when using fused linear cross entropy."
168-
# [b s] => [s b]
169-
labels = labels.transpose(0, 1).contiguous()
170-
loss = linear_cross_entropy(
171-
hidden,
172-
weight,
173-
labels,
174-
tp_group=self.pg_collection.tp,
175-
sequence_parallel=sequence_parallel_enabled,
176-
reduction=reduction,
177-
ignore_index=ignore_index,
178-
)
179-
180-
# [s b] => [b, s]
181-
loss = loss.view_as(labels).transpose(0, 1).contiguous()
182-
return loss
183-
else:
184-
assert (
185-
column_parallel_linear is not None
186-
), "column_parallel_linear cannot be None when not using fused linear cross entropy."
187-
logits, _ = column_parallel_linear(hidden, **col_linear_kwargs)
188-
189-
return self.compute_language_model_loss(labels, logits)
190-
191128
def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
192129
"""Computes the language model loss (Cross entropy across vocabulary)
193130

megatron/core/models/gpt/gpt_model.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from megatron.core.quantization.utils import get_quant_config_or_none
2626
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
2727
from megatron.core.transformer.enums import CudaGraphScope, ModelType
28+
from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule
2829
from megatron.core.transformer.multi_token_prediction import (
2930
MTPLossAutoScaler,
3031
MTPLossLoggingHelper,
@@ -238,7 +239,7 @@ def __init__(
238239
self.embedding_activation_buffer = None
239240
self.grad_output_buffer = None
240241

241-
self.output_layer = tensor_parallel.ColumnParallelLinear(
242+
self.output_layer = LinearCrossEntropyModule(
242243
config.hidden_size,
243244
self.vocab_size,
244245
config=config,
@@ -633,16 +634,12 @@ def _postprocess(
633634
)
634635

635636
# Compute mtp loss without storing logits to save memory.
636-
mtp_loss = self.compute_output_layer_and_language_model_loss(
637-
hidden_states_list[mtp_layer_number + 1],
637+
mtp_loss = self.output_layer(
638+
output_cross_entropy_loss=True,
639+
input_=hidden_states_list[mtp_layer_number + 1],
640+
weight=output_weight,
638641
labels=mtp_labels,
639-
weight=self.shared_embedding_or_output_weight(),
640-
sequence_parallel_enabled=self.output_layer.sequence_parallel,
641-
column_parallel_linear=self.output_layer,
642-
col_linear_kwargs={
643-
'weight': output_weight,
644-
'runtime_gather_output': runtime_gather_output,
645-
},
642+
runtime_gather_output=runtime_gather_output,
646643
)
647644

648645
mtp_loss = loss_mask * mtp_loss
@@ -721,16 +718,12 @@ def _postprocess(
721718
# [s b h] => [b s h]
722719
return logits.transpose(0, 1).contiguous()
723720

724-
loss = self.compute_output_layer_and_language_model_loss(
725-
hidden_states,
721+
loss = self.output_layer(
722+
output_cross_entropy_loss=True,
723+
input_=hidden_states,
726724
labels=labels,
727-
weight=self.shared_embedding_or_output_weight(),
728-
sequence_parallel_enabled=self.output_layer.sequence_parallel,
729-
column_parallel_linear=self.output_layer,
730-
col_linear_kwargs={
731-
'weight': output_weight,
732-
'runtime_gather_output': runtime_gather_output,
733-
},
725+
weight=output_weight,
726+
runtime_gather_output=runtime_gather_output,
734727
)
735728

736729
return loss

megatron/core/models/mamba/mamba_model.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from torch import Tensor
66

7-
from megatron.core import tensor_parallel
87
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
98
from megatron.core.inference.contexts import BaseInferenceContext
109
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
@@ -16,6 +15,7 @@
1615
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
1716
from megatron.core.transformer import TransformerConfig
1817
from megatron.core.transformer.enums import ModelType
18+
from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule
1919
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
2020
from megatron.core.utils import (
2121
WrappedTensor,
@@ -136,7 +136,7 @@ def __init__(
136136

137137
# Output
138138
if post_process:
139-
self.output_layer = tensor_parallel.ColumnParallelLinear(
139+
self.output_layer = LinearCrossEntropyModule(
140140
config.hidden_size,
141141
self.vocab_size,
142142
config=config,
@@ -304,16 +304,12 @@ def forward(
304304
# [s b h] => [b s h]
305305
return logits.transpose(0, 1).contiguous()
306306

307-
loss = self.compute_output_layer_and_language_model_loss(
308-
hidden_states,
309-
labels,
310-
weight=self.shared_embedding_or_output_weight(),
311-
sequence_parallel_enabled=self.output_layer.sequence_parallel,
312-
column_parallel_linear=self.output_layer,
313-
col_linear_kwargs={
314-
"weight": output_weight,
315-
"runtime_gather_output": runtime_gather_output,
316-
},
307+
loss = self.output_layer(
308+
output_cross_entropy_loss=True,
309+
input_=hidden_states,
310+
labels=labels,
311+
weight=output_weight,
312+
runtime_gather_output=runtime_gather_output,
317313
)
318314

319315
return loss
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
3+
from typing import Literal, Optional, Tuple, Union
4+
5+
import torch
6+
7+
from megatron.core import tensor_parallel
8+
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
9+
from megatron.core.fusions.fused_linear_cross_entropy import linear_cross_entropy
10+
from megatron.core.transformer.enums import CudaGraphScope
11+
from megatron.core.utils import is_te_min_version
12+
13+
try:
14+
from megatron.core.extensions.transformer_engine import te_parallel_cross_entropy
15+
except:
16+
te_parallel_cross_entropy = None
17+
18+
19+
class LinearCrossEntropyModule(tensor_parallel.ColumnParallelLinear):
    """
    A module that combines a ColumnParallelLinear layer with fused
    linear + cross-entropy loss computation over a tensor-parallel vocabulary.
    """

    def forward(
        self,
        input_: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        runtime_gather_output: Optional[bool] = None,
        output_cross_entropy_loss: bool = False,
        labels: Optional[torch.Tensor] = None,
        reduction: Literal["none", "sum", "mean"] = "none",
        ignore_index: int = -100,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Run either the plain ColumnParallelLinear or fused linear+cross-entropy.

        Args:
            input_: Hidden states fed to the output projection.
            weight: Optional externally supplied output weight (e.g. a shared
                embedding weight). Falls back to ``self.weight`` when None.
            runtime_gather_output: Forwarded to ``ColumnParallelLinear.forward``
                (and to the unfused fallback when computing the loss).
            output_cross_entropy_loss: When True, compute and return the
                cross-entropy loss instead of logits.
            labels: Target token ids; required when
                ``output_cross_entropy_loss`` is True. Assumed [b, s] —
                TODO confirm against callers.
            reduction: Loss reduction for the fused path.
            ignore_index: Label value excluded from the fused loss.

        Returns:
            The loss tensor when ``output_cross_entropy_loss`` is True,
            otherwise the result of ``ColumnParallelLinear.forward``.
        """
        if output_cross_entropy_loss:
            assert labels is not None, "labels cannot be None when outputting cross-entropy loss."
            return self._compute_linear_and_cross_entropy_loss(
                hidden=input_,
                weight=weight if weight is not None else self.weight,
                # BUGFIX: forward runtime_gather_output so the unfused fallback
                # inside the loss path behaves like the plain forward; it was
                # previously dropped and always defaulted to None.
                runtime_gather_output=runtime_gather_output,
                labels=labels,
                reduction=reduction,
                ignore_index=ignore_index,
            )

        # Fall back to standard ColumnParallelLinear forward.
        # ColumnParallelLinear.forward returns (output, bias) or just output
        # depending on configuration, so keep the return type as Tensor.
        return super().forward(input_, weight, runtime_gather_output)

    def _compute_linear_and_cross_entropy_loss(
        self,
        hidden: torch.Tensor,
        weight: torch.Tensor,
        runtime_gather_output: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        reduction: Literal["none", "sum", "mean"] = "none",
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """Compute fused linear + cross-entropy over tensor-parallel vocab.

        Uses the fused ``linear_cross_entropy`` kernel when enabled via config;
        otherwise falls back to a plain forward followed by a (possibly fused)
        vocab-parallel cross-entropy.
        """
        if (
            self.config.cross_entropy_loss_fusion
            and self.config.cross_entropy_fusion_impl == 'linear'
        ):
            assert (
                weight is not None
            ), "weight cannot be None when using fused linear cross entropy."
            assert (
                labels is not None
            ), "labels cannot be None when using fused linear cross entropy."

            # [b s] => [s b]
            labels = labels.transpose(0, 1).contiguous()
            loss = linear_cross_entropy(
                hidden,
                # BUGFIX: use the resolved `weight` argument, not `self.weight`.
                # Callers may pass a shared embedding weight; using self.weight
                # here silently ignored it and broke tied-embedding training.
                weight,
                labels,
                sequence_parallel=self.sequence_parallel,
                reduction=reduction,
                ignore_index=ignore_index,
                tp_group=self.tp_group,
            )
            # If reduction != "none" this will be a scalar; for "none" it should
            # match [s, b] and can be reshaped back to [b, s].
            if reduction == "none":
                loss = loss.view_as(labels).transpose(0, 1).contiguous()
        else:
            logits, _ = super().forward(hidden, weight, runtime_gather_output)
            loss = self._compute_cross_entropy_loss(labels, logits)

        return loss

    def _compute_cross_entropy_loss(
        self, labels: torch.Tensor, logits: torch.Tensor
    ) -> Optional[torch.Tensor]:
        """Compute (possibly fused) vocab-parallel cross-entropy loss."""
        loss = None

        # [b s] => [s b]
        labels = labels.transpose(0, 1).contiguous()
        if self.config.cross_entropy_loss_fusion:
            if self.config.cross_entropy_fusion_impl == 'te':
                if te_parallel_cross_entropy is not None:
                    labels = torch.as_strided(labels, labels.size(), (labels.size()[1], 1))
                    # Use is_cg_capturable=True for full iteration CUDA graphs
                    # to avoid torch.equal checks
                    is_cg_capturable = (
                        hasattr(self.config, 'cuda_graph_scope')
                        and CudaGraphScope.full_iteration in self.config.cuda_graph_scope
                    )
                    if is_cg_capturable and not is_te_min_version("2.7.0"):
                        from megatron.core.utils import get_te_version

                        current_version = get_te_version()
                        raise AssertionError(
                            f"CUDA graph compatible cross entropy requires "
                            f"TransformerEngine >= 2.7.0, but found version {current_version}. "
                            "Please upgrade TransformerEngine "
                            f"or set cuda_graph_scope to a value other than 'full_iteration'."
                        )

                    loss = te_parallel_cross_entropy(
                        logits, labels, self.tp_group, is_cg_capturable
                    )
                else:
                    raise RuntimeError("Trying to use a TE block when it's not present.")
            elif self.config.cross_entropy_fusion_impl == 'native':
                loss = fused_vocab_parallel_cross_entropy(logits, labels, self.tp_group)
        else:
            loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)

        # [s b] => [b, s]
        loss = loss.transpose(0, 1).contiguous()
        return loss

0 commit comments

Comments
 (0)