Move Triton to common #2359
base: main
Conversation
Signed-off-by: tdophung <[email protected]>
/te_ci L2 pytorch
Greptile Overview
Greptile Summary
Relocates Triton kernel implementations from transformer_engine/pytorch/triton/ to transformer_engine/common/triton/ to enable sharing between PyTorch and JAX backends.
Key Changes:
- Moved 4 files (`__init__.py`, `cross_entropy.py`, `pad.py`, `permutation.py`) from the PyTorch-specific directory to the common directory
- Updated 3 PyTorch import statements to reference the new location (see the sketch after this list):
  - `transformer_engine/pytorch/cross_entropy.py:9`
  - `transformer_engine/pytorch/distributed.py:49`
  - `transformer_engine/pytorch/permutation.py:11`
- No changes to the kernel implementation code itself: a pure file relocation
- All import paths correctly updated, with no remaining references to the old location
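Each import update amounts to a path change only. A minimal sketch of the before/after, using `online_softmax_kernel` (the one kernel named in the sequence diagram further down; the exact symbols imported in each file are not shown in this summary):

```python
# Before the move (PyTorch-specific location):
# from transformer_engine.pytorch.triton.cross_entropy import online_softmax_kernel

# After the move (shared location, importable by PyTorch and, later, JAX):
from transformer_engine.common.triton.cross_entropy import online_softmax_kernel
```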
Impact:
This is a foundational refactoring to support future MOE functionality in JAX by making Triton kernels framework-agnostic.
Confidence Score: 5/5
- This PR is safe to merge with no risk - pure code relocation with correct import updates
- Perfect score because this is a straightforward file move with no logic changes. All imports are correctly updated, no references to old paths remain, and the kernel implementations are unchanged. The refactoring is clean and complete.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/triton/cross_entropy.py | 5/5 | Moved from pytorch/triton - cross entropy Triton kernels, no code changes |
| transformer_engine/common/triton/pad.py | 5/5 | Moved from pytorch/triton - zero padding kernels, no code changes |
| transformer_engine/common/triton/permutation.py | 5/5 | Moved from pytorch/triton - MoE permutation kernels, no code changes |
| transformer_engine/pytorch/cross_entropy.py | 5/5 | Updated import from pytorch.triton.cross_entropy to common.triton.cross_entropy |
| transformer_engine/pytorch/distributed.py | 5/5 | Updated import from .triton.pad to transformer_engine.common.triton.pad |
| transformer_engine/pytorch/permutation.py | 5/5 | Updated import from pytorch.triton.permutation to common.triton.permutation |
Sequence Diagram
```mermaid
sequenceDiagram
    participant PT as PyTorch Modules
    participant Common as transformer_engine/common/triton
    participant Kernels as Triton Kernels
    Note over PT,Kernels: Before: PyTorch modules imported from pytorch/triton
    PT->>PT: import pytorch.triton.cross_entropy
    PT->>PT: import pytorch.triton.permutation
    PT->>PT: import pytorch.triton.pad
    Note over PT,Kernels: After: Modules moved to common, imports updated
    PT->>Common: import common.triton.cross_entropy
    PT->>Common: import common.triton.permutation
    PT->>Common: import common.triton.pad
    Note over Common: Triton kernels now in common<br/>Ready for JAX integration
    Common->>Kernels: Execute Triton kernels
    Kernels-->>Common: Return results
    Common-->>PT: Return results
```
7 files reviewed, no comments
/te_ci pytorch
Signed-off-by: tdophung <[email protected]>
/te_ci pytorch
Greptile Overview
Greptile Summary
This PR relocates Triton kernel implementations from transformer_engine/pytorch/triton to transformer_engine/common/triton as part of a multi-step initiative to enable JAX MOE support. The refactoring is clean and maintains backward compatibility:
- New files created: three kernel modules (`cross_entropy.py`, `permutation.py`, `pad.py`) and an `__init__.py`, moved to `transformer_engine/common/triton/`
- PyTorch layer refactored: the `pytorch/triton` modules now import the kernels from common and provide PyTorch-specific wrapper functions (see the sketch after this list)
- No logic changes: all kernel code is identical to the original; this is a pure code relocation
- Import chain preserved: external code importing from `transformer_engine.pytorch` is unaffected
- Clean separation: the common layer contains pure Triton kernels, while the PyTorch layer handles torch tensors and distributed logic
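A minimal sketch of the resulting layering. The wrapper name, kernel signature, and launch arguments are hypothetical; only `online_softmax_kernel` is taken from the sequence diagram below:

```python
# transformer_engine/pytorch/triton/cross_entropy.py -- illustrative sketch only.
# The wrapper owns all torch-specific logic; the shared kernel stays torch-free.
import torch

from transformer_engine.common.triton.cross_entropy import online_softmax_kernel


def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical PyTorch wrapper around the shared Triton kernel."""
    n_rows, n_cols = logits.shape
    loss = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
    # Triton accepts torch tensors as pointer arguments at launch time,
    # so the kernel module itself never needs to import torch.
    online_softmax_kernel[(n_rows,)](logits, labels, loss, n_cols)
    return loss
```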
Confidence Score: 5/5
- This PR is safe to merge - it's a clean code refactoring with no logic changes
- This is a well-executed refactoring with zero logic changes. The Triton kernels are moved verbatim to common, PyTorch wrappers correctly import from the new location, and all external APIs remain unchanged. The separation enables future JAX integration without affecting existing PyTorch functionality.
- No files require special attention - all changes are straightforward code relocations
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/triton/cross_entropy.py | 5/5 | Triton kernel implementations moved from pytorch to common - identical code, no logic changes |
| transformer_engine/common/triton/pad.py | 5/5 | Triton padding kernel moved from pytorch to common - identical code, no logic changes |
| transformer_engine/common/triton/permutation.py | 5/5 | Triton permutation kernels moved from pytorch to common - identical code, no logic changes |
| transformer_engine/pytorch/triton/cross_entropy.py | 5/5 | Refactored to import kernels from common and provide PyTorch wrapper functions |
| transformer_engine/pytorch/triton/pad.py | 5/5 | Refactored to import padding kernel from common and provide PyTorch wrapper |
| transformer_engine/pytorch/triton/permutation.py | 5/5 | Refactored to import permutation kernels from common and provide PyTorch wrappers |
Sequence Diagram
```mermaid
sequenceDiagram
    participant JAX as JAX Layer (Future)
    participant PyTorch as PyTorch Layer
    participant Common as Common Triton Kernels
    participant GPU as GPU/Triton Runtime
    Note over PyTorch,Common: Before PR: PyTorch had kernels
    Note over Common,JAX: After PR: Kernels moved to common
    PyTorch->>Common: Import cross_entropy kernels
    PyTorch->>Common: Import permutation kernels
    PyTorch->>Common: Import pad kernels
    Note over PyTorch: PyTorch provides<br/>wrapper functions<br/>with torch-specific logic
    PyTorch->>Common: Call online_softmax_kernel
    Common->>GPU: Execute Triton kernel
    GPU-->>Common: Return result
    Common-->>PyTorch: Return result
    Note over JAX,Common: Future: JAX will also<br/>import from common<br/>for MOE support
```
10 files reviewed, no comments
Description
Step 1 in a multi-step process to have JAX execute Triton kernels to support MOE.
Changes
Please list the changes introduced in this PR:
- Relocate Triton kernels to `transformer_engine/common/triton` so they can be shared by both JAX and PyTorch (pure Triton kernels, without any use or import of PyTorch); a sketch of such a kernel follows below
- Keep the PyTorch wrappers of these kernels in `transformer_engine/pytorch/triton` (new files containing functions that call the Triton kernels but use torch tensors)
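To make "pure Triton kernels" concrete, here is a minimal torch-free kernel in the spirit of the relocated `pad.py`. It is illustrative only, not code from this PR; because it imports only `triton`, the same module can be launched from a PyTorch wrapper today and from a JAX caller later:

```python
import triton
import triton.language as tl


@triton.jit
def zero_pad_1d_kernel(x_ptr, out_ptr, n_in, n_out, BLOCK: tl.constexpr):
    """Copy n_in elements into out and zero-fill positions n_in..n_out-1."""
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    in_bounds = offsets < n_out
    # Loads past n_in fall back to 0.0 via `other`, which produces the padding.
    vals = tl.load(x_ptr + offsets, mask=in_bounds & (offsets < n_in), other=0.0)
    tl.store(out_ptr + offsets, vals, mask=in_bounds)
```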