cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <[email protected]>
pggPL committed Nov 4, 2025
commit 19da61bcdfeb1b1e715b467362f855a10f8c75af
3 changes: 1 addition & 2 deletions docs/api/pytorch.rst
@@ -49,6 +49,7 @@ pyTorch

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context

.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy

Recipe availability
-------------------
@@ -80,8 +81,6 @@ Mixture of Experts (MoE) functions

.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index

.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy

.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs


37 changes: 25 additions & 12 deletions transformer_engine/pytorch/cross_entropy.py
@@ -5,6 +5,7 @@
"""Cross Entropy Loss API"""

from typing import Optional
import warnings

import torch

@@ -25,7 +26,7 @@ class CrossEntropyFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
_input,
input,
Member: I saw @ptrendx's comment above suggesting we should change it, but input is also a python keyword and we should avoid using it. Maybe we can go with inp as we do in our other modules. I'm surprised that the linter didn't complain about this change

Member: That's a good point actually.

Collaborator Author: Changed to inp
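For context on the thread above: `input` is a builtin rather than a keyword, so the rename still parses, but a parameter with that name shadows `builtins.input` inside the function body, which is what pylint's `redefined-builtin` (W0622) check flags. A minimal sketch with hypothetical function names:

```python
# Hypothetical functions, for illustration only (not part of this PR).
def loss_with_shadowing(input, target):
    # `input` now refers to the tensor argument; the builtin input() is
    # shadowed in this scope (pylint: redefined-builtin / W0622).
    return (input - target).abs().mean()


def loss_without_shadowing(inp, target):
    # `inp` avoids the shadowing and matches the naming used in other
    # Transformer Engine modules.
    return (inp - target).abs().mean()
```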
target,
label_smoothing=0.0,
reduce_loss=False,
@@ -39,7 +40,7 @@ def forward(

Parameters:
ctx : The context object.
_input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size.
input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size.
target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1].
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
@@ -49,16 +50,16 @@
Returns:
tensor: The computed loss.
"""
loss, _input = triton_cross_entropy.cross_entropy_forward(
_input,
loss, input = triton_cross_entropy.cross_entropy_forward(
input,
target,
label_smoothing,
reduce_loss,
dist_process_group,
ignore_idx,
)

ctx.save_for_backward(_input.detach())
ctx.save_for_backward(input.detach())
ctx.is_cg_capturable = is_cg_capturable
return loss

@@ -74,12 +75,12 @@ def backward(ctx, grad_output):
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
(_input,) = ctx.saved_tensors
_input = triton_cross_entropy.cross_entropy_backward(
_input, grad_output, ctx.is_cg_capturable
(input,) = ctx.saved_tensors
input = triton_cross_entropy.cross_entropy_backward(
input, grad_output, ctx.is_cg_capturable
)
return (
_input,
input,
None,
None,
None,
@@ -90,13 +91,15 @@


def parallel_cross_entropy(
_input: torch.Tensor,
input: torch.Tensor,
target: torch.Tensor,
label_smoothing: float = 0.0,
reduce_loss: bool = False,
dist_process_group: Optional[torch.distributed.ProcessGroup] = None,
ignore_idx: int = -100,
is_cg_capturable: bool = False,
*,
_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Cross Entropy loss with optional distributed reduction.
@@ -111,7 +114,7 @@

Parameters
----------
_input : torch.Tensor
input : torch.Tensor
The input tensor of shape ``(B, SQ, V)`` or ``(SQ, B, V)`` where B is batch size,
SQ is sequence length, V is vocab size.
target : torch.Tensor
@@ -132,8 +135,18 @@
torch.Tensor
The computed loss.
"""
# Handle backward compatibility with _input parameter
if _input is not None:
warnings.warn(
"The '_input' parameter is deprecated and will be removed in a future version. "
"Please use 'input' instead.",
FutureWarning,
stacklevel=2,
)
input = _input

return CrossEntropyFunction.apply(
_input,
input,
target,
label_smoothing,
reduce_loss,
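A minimal usage sketch of the renamed API, assuming a CUDA build of Transformer Engine that exposes `parallel_cross_entropy` as documented above; shapes follow the docstring and the sizes are illustrative:

```python
import torch
from transformer_engine.pytorch import parallel_cross_entropy

B, SQ, V = 4, 128, 32000  # batch, sequence length, vocab size
logits = torch.randn(B, SQ, V, device="cuda", requires_grad=True)  # (B, SQ, V)
target = torch.randint(0, V, (B, SQ), device="cuda")               # (B, SQ)

# After the rename, the first argument is `input`; pass it positionally
# (or as input=) going forward.
loss = parallel_cross_entropy(logits, target, label_smoothing=0.0, reduce_loss=True)
loss.backward()  # gradients reach `logits` via CrossEntropyFunction.backward
```

The keyword-only `_input` parameter is kept only as a deprecation shim that forwards to `input` and emits a FutureWarning; new code should not pass it.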