
Conversation


@tdophung tdophung commented Nov 7, 2025

Description

Step 1 in a multi-step process to have JAX execute Triton kernels to support MoE.
Steps:

  • Move Triton kernels to common,
    then:
  • Create a C++ wrapper that executes the Triton kernels
  • Register this C++ wrapper as an XLA custom call for JAX,
    OR
  • Use jax-triton to call the Triton kernels (needs experimentation; see the sketch after this list)
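For the jax-triton option, here is a minimal sketch of what calling a Triton kernel from JAX could look like, assuming the jax-triton package is available. The kernel below is an illustrative element-wise add, not one of the TE kernels, and the actual integration still needs the experimentation mentioned above.

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, length, out_ptr, BLOCK_SIZE: tl.constexpr):
    # Toy element-wise add standing in for the relocated TE kernels.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < length
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    # jax-triton appends the output buffer(s) described by out_shape after the
    # positional inputs, so the call can be used inside jit-traced JAX code.
    block_size = 1024
    return jt.triton_call(
        x, y, x.size,
        kernel=add_kernel,
        out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype),
        grid=(triton.cdiv(x.size, block_size),),
        BLOCK_SIZE=block_size,
    )


x = jnp.arange(8, dtype=jnp.float32)
print(jax.jit(add)(x, x))
```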

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:
Relocate the Triton kernels to transformer_engine/common/triton so they can be shared by both JAX and PyTorch (pure Triton kernels, without any use or import of PyTorch).
Keep the PyTorch wrappers of these kernels in transformer_engine/pytorch/triton (new files containing functions that call the Triton kernels but operate on torch tensors); see the sketch below.
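A rough sketch of the intended split, under assumed file contents (the kernel and function names below are illustrative placeholders, not the actual TE kernels): the common module holds the pure Triton kernel, and the PyTorch-side module wraps it with torch tensor handling.

```python
# --- transformer_engine/common/triton/pad.py (framework-agnostic: Triton-only imports) ---
import triton
import triton.language as tl


@triton.jit
def zero_pad_kernel(inp_ptr, out_ptr, n_in, n_out, BLOCK: tl.constexpr):
    # Illustrative kernel: copy the input and zero-fill the padded tail.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    val = tl.load(inp_ptr + offs, mask=offs < n_in, other=0.0)
    tl.store(out_ptr + offs, val, mask=offs < n_out)


# --- transformer_engine/pytorch/triton/pad.py (torch-aware wrapper) ---
import torch
# In the real file the kernel would be imported rather than defined locally:
# from transformer_engine.common.triton.pad import zero_pad_kernel

def zero_pad(inp: torch.Tensor, padded_len: int) -> torch.Tensor:
    # Allocate the output with torch and launch the shared Triton kernel;
    # Triton accepts torch tensors directly as pointer arguments.
    out = torch.empty(padded_len, dtype=inp.dtype, device=inp.device)
    BLOCK = 1024
    grid = (triton.cdiv(padded_len, BLOCK),)
    zero_pad_kernel[grid](inp, out, inp.numel(), padded_len, BLOCK=BLOCK)
    return out
```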

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


tdophung commented Nov 7, 2025

/te_ci L2 pytorch


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Relocates Triton kernel implementations from transformer_engine/pytorch/triton/ to transformer_engine/common/triton/ to enable sharing between PyTorch and JAX backends.

Key Changes:

  • Moved 4 files (__init__.py, cross_entropy.py, pad.py, permutation.py) from PyTorch-specific directory to common directory
  • Updated 3 PyTorch import statements to reference the new location:
    • transformer_engine/pytorch/cross_entropy.py:9
    • transformer_engine/pytorch/distributed.py:49
    • transformer_engine/pytorch/permutation.py:11
  • No changes to kernel implementation code itself - pure file relocation
  • All import paths correctly updated, with no remaining references to the old location (see the illustrative import forms after this list)
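The new module paths in those three files would follow this pattern (the aliases and symbols below are placeholders for illustration; the exact names imported are not reproduced in this summary):

```python
# Placeholder import forms showing only the new, shared module paths.
from transformer_engine.common.triton import cross_entropy as triton_cross_entropy  # in pytorch/cross_entropy.py
from transformer_engine.common.triton import pad as triton_pad                      # in pytorch/distributed.py
from transformer_engine.common.triton import permutation as triton_permutation      # in pytorch/permutation.py
```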

Impact:
This is a foundational refactoring to support future MoE functionality in JAX by making the Triton kernels framework-agnostic.

Confidence Score: 5/5

  • This PR is safe to merge with no risk - pure code relocation with correct import updates
  • Perfect score because this is a straightforward file move with no logic changes. All imports are correctly updated, no references to old paths remain, and the kernel implementations are unchanged. The refactoring is clean and complete.
  • No files require special attention

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/common/triton/cross_entropy.py | 5/5 | Moved from pytorch/triton - cross entropy Triton kernels, no code changes |
| transformer_engine/common/triton/pad.py | 5/5 | Moved from pytorch/triton - zero padding kernels, no code changes |
| transformer_engine/common/triton/permutation.py | 5/5 | Moved from pytorch/triton - MoE permutation kernels, no code changes |
| transformer_engine/pytorch/cross_entropy.py | 5/5 | Updated import from pytorch.triton.cross_entropy to common.triton.cross_entropy |
| transformer_engine/pytorch/distributed.py | 5/5 | Updated import from .triton.pad to transformer_engine.common.triton.pad |
| transformer_engine/pytorch/permutation.py | 5/5 | Updated import from pytorch.triton.permutation to common.triton.permutation |

Sequence Diagram

sequenceDiagram
    participant PT as PyTorch Modules
    participant Common as transformer_engine/common/triton
    participant Kernels as Triton Kernels

    Note over PT,Kernels: Before: PyTorch modules imported from pytorch/triton
    PT->>PT: import pytorch.triton.cross_entropy
    PT->>PT: import pytorch.triton.permutation
    PT->>PT: import pytorch.triton.pad

    Note over PT,Kernels: After: Modules moved to common, imports updated
    PT->>Common: import common.triton.cross_entropy
    PT->>Common: import common.triton.permutation
    PT->>Common: import common.triton.pad
    
    Note over Common: Triton kernels now in common<br/>Ready for JAX integration
    Common->>Kernels: Execute Triton kernels
    Kernels-->>Common: Return results
    Common-->>PT: Return results

7 files reviewed, no comments



tdophung commented Nov 7, 2025

/te_ci pytorch

Signed-off-by: tdophung <[email protected]>

tdophung commented Nov 7, 2025

/te_ci pytorch


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR relocates Triton kernel implementations from transformer_engine/pytorch/triton to transformer_engine/common/triton as part of a multi-step initiative to enable JAX MoE support. The refactoring is clean and maintains backward compatibility:

  • New files created: Three kernel modules (cross_entropy.py, permutation.py, pad.py) and an __init__.py moved to transformer_engine/common/triton/
  • PyTorch layer refactored: The pytorch/triton modules now import kernels from common and provide PyTorch-specific wrapper functions
  • No logic changes: All kernel code is identical to the original - this is a pure code relocation
  • Import chain preserved: External code importing from transformer_engine.pytorch is unaffected
  • Clean separation: Common layer contains pure Triton kernels, PyTorch layer handles torch tensors and distributed logic

Confidence Score: 5/5

  • This PR is safe to merge - it's a clean code refactoring with no logic changes
  • This is a well-executed refactoring with zero logic changes. The Triton kernels are moved verbatim to common, PyTorch wrappers correctly import from the new location, and all external APIs remain unchanged. The separation enables future JAX integration without affecting existing PyTorch functionality.
  • No files require special attention - all changes are straightforward code relocations

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/common/triton/cross_entropy.py | 5/5 | Triton kernel implementations moved from pytorch to common - identical code, no logic changes |
| transformer_engine/common/triton/pad.py | 5/5 | Triton padding kernel moved from pytorch to common - identical code, no logic changes |
| transformer_engine/common/triton/permutation.py | 5/5 | Triton permutation kernels moved from pytorch to common - identical code, no logic changes |
| transformer_engine/pytorch/triton/cross_entropy.py | 5/5 | Refactored to import kernels from common and provide PyTorch wrapper functions |
| transformer_engine/pytorch/triton/pad.py | 5/5 | Refactored to import padding kernel from common and provide PyTorch wrapper |
| transformer_engine/pytorch/triton/permutation.py | 5/5 | Refactored to import permutation kernels from common and provide PyTorch wrappers |

Sequence Diagram

sequenceDiagram
    participant JAX as JAX Layer (Future)
    participant PyTorch as PyTorch Layer
    participant Common as Common Triton Kernels
    participant GPU as GPU/Triton Runtime

    Note over PyTorch,Common: Before PR: PyTorch had kernels
    Note over Common,JAX: After PR: Kernels moved to common

    PyTorch->>Common: Import cross_entropy kernels
    PyTorch->>Common: Import permutation kernels
    PyTorch->>Common: Import pad kernels
    
    Note over PyTorch: PyTorch provides<br/>wrapper functions<br/>with torch-specific logic
    
    PyTorch->>Common: Call online_softmax_kernel
    Common->>GPU: Execute Triton kernel
    GPU-->>Common: Return result
    Common-->>PyTorch: Return result
    
    Note over JAX,Common: Future: JAX will also<br/>import from common<br/>for MOE support

10 files reviewed, no comments

