2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -18,7 +18,7 @@ jobs:
      - name: 'Install dependencies'
        run: |
          pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2
-         pip install breathe==4.35.0 sphinx-autoapi==3.3.2
+         pip install breathe==4.35.0 sphinx-autoapi==3.3.2 sphinx-tabs==3.4.7
          sudo apt-get install -y pandoc graphviz doxygen
          export GIT_SHA=$(git show-ref --hash HEAD)
      - name: 'Build docs'
132 changes: 132 additions & 0 deletions docs/_static/css/diagram-colors.css
@@ -0,0 +1,132 @@
/* Diagram color definitions for Transformer Engine documentation */

/* High precision (BF16/FP16) elements */
.hp {
  fill: #ede7f6;
  stroke: #673ab7;
  stroke-width: 2;
}

/* FP8 precision elements */
.fp8 {
  fill: #fff8e1;
  stroke: #ffa726;
  stroke-width: 2;
}

/* GEMM/computation operations */
.gemm {
  fill: #ffe0b2;
  stroke: #fb8c00;
  stroke-width: 2.5;
}

/* Quantization operations */
.quantize {
  fill: #e8f5e9;
  stroke: #66bb6a;
  stroke-width: 2;
}

/* Amax computation operations */
.amax {
  fill: #e1f5fe;
  stroke: #039be5;
  stroke-width: 2;
}

/* Text styles */
.text {
  font-family: 'Segoe UI', Arial, sans-serif;
  font-size: 14px;
  text-anchor: middle;
  fill: #212121;
}

.small-text {
  font-family: 'Segoe UI', Arial, sans-serif;
  font-size: 11px;
  text-anchor: middle;
  fill: #757575;
}

.label {
  font-family: 'Segoe UI', Arial, sans-serif;
  font-size: 11px;
  text-anchor: middle;
  fill: #424242;
}

.title {
  font-family: 'Segoe UI', Arial, sans-serif;
  font-size: 18px;
  font-weight: 600;
  text-anchor: middle;
  fill: #212121;
}

.section-title {
  font-family: 'Segoe UI', Arial, sans-serif;
  font-size: 15px;
  font-weight: 600;
  text-anchor: middle;
}

/* Arrows */
.arrow {
  stroke: #616161;
  stroke-width: 2;
  fill: none;
}

/* Additional box and element styles */
.box-blue {
  fill: #e3f2fd;
  stroke: #1976d2;
  stroke-width: 2;
}

.box-orange {
  fill: #fff3e0;
  stroke: #f57c00;
  stroke-width: 2;
}

.box-green {
  fill: #c8e6c9;
  stroke: #388e3c;
  stroke-width: 2;
}

.box-dashed {
  stroke-dasharray: 5,5;
}

/* LayerNorm specific */
.layernorm {
  fill: #b3e5fc;
  stroke: #0277bd;
  stroke-width: 2.5;
}

/* Fused layers */
.fused {
  fill: #b2dfdb;
  stroke: #00695c;
  stroke-width: 3;
}

/* Generic computation blocks */
.computation {
  fill: #f5f5f5;
  stroke: #757575;
  stroke-width: 2;
}

/* FP32 precision (alternative red) */
.fp32 {
  fill: #ffcdd2;
  stroke: #d32f2f;
  stroke-width: 2.5;
}

43 changes: 43 additions & 0 deletions docs/_static/css/sphinx_tabs.css
@@ -0,0 +1,43 @@
/* Custom styling for sphinx-tabs */

.sphinx-tabs {
  margin-bottom: 1rem;
}

.sphinx-tabs-tab {
  background-color: #f4f4f4;
  border: 1px solid #ccc;
  border-bottom: none;
  padding: 0.5rem 1rem;
  margin-right: 0.25rem;
  cursor: pointer;
  font-weight: 500;
  transition: background-color 0.2s;
}

.sphinx-tabs-tab:hover {
  background-color: #e0e0e0;
}

.sphinx-tabs-tab[aria-selected="true"] {
  background-color: #76b900; /* NVIDIA green */
  color: white;
  border-color: #76b900;
}

.sphinx-tabs-panel {
  border: 1px solid #ccc;
  padding: 1rem;
  background-color: #f9f9f9;
}

/* Dark mode support for RTD theme */
.rst-content .sphinx-tabs-tab {
  color: #333;
}

.rst-content .sphinx-tabs-tab[aria-selected="true"] {
  color: white;
}


4 changes: 4 additions & 0 deletions docs/_templates/layout.html
@@ -67,6 +67,10 @@
    overflow: visible !important;
  }

+ .quant {
+   background-color: yellow !important;
+ }
+
</style>
<style>
a:link, a:visited {
3 changes: 3 additions & 0 deletions docs/conf.py
@@ -58,6 +58,7 @@
    "nbsphinx",
    "breathe",
    "autoapi.extension",
+   "sphinx_tabs.tabs",
]

templates_path = ["_templates"]
@@ -79,6 +80,8 @@
html_css_files = [
    "css/nvidia_font.css",
    "css/nvidia_footer.css",
+   "css/diagram-colors.css",
+   "css/sphinx_tabs.css",
]

html_theme_options = {
85 changes: 85 additions & 0 deletions docs/debug/1_getting_started.rst
@@ -239,3 +239,88 @@ Let's run training and open TensorBoard by ``tensorboard --logdir=./tensorboard_
   :align: center

   Fig 2: TensorBoard with plotted stats.

Real-world Examples
-------------------

Below are complete working examples from the repository that demonstrate various use cases of Transformer Engine.

.. tabs::

   .. tab:: MNIST Training

      This example shows how to integrate Transformer Engine into a standard PyTorch training script.

      Model Definition
      ^^^^^^^^^^^^^^^^

      First, we define our neural network. Notice how we use ``te.Linear`` instead of ``nn.Linear``
      for the fully connected layers that will benefit from FP8 computation:

      .. literalinclude:: ../../examples/pytorch/mnist/main.py
         :language: python
         :lines: 16-29
         :linenos:
         :emphasize-lines: 8-9

      The ``use_te`` flag allows us to switch between standard PyTorch Linear layers and
      Transformer Engine Linear layers for benchmarking.
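
      In rough outline, the pattern looks like this. This is a minimal sketch, not the exact
      code in ``main.py``; the class name and layer sizes are illustrative:

      .. code-block:: python

         import torch
         import torch.nn as nn
         import transformer_engine.pytorch as te

         class Net(nn.Module):
             """Illustrative sketch, not the exact model in main.py."""

             def __init__(self, use_te: bool = False):
                 super().__init__()
                 # FP8 GEMMs require dimensions divisible by 16, so only the
                 # large hidden projections are switched to te.Linear here.
                 Linear = te.Linear if use_te else nn.Linear
                 self.fc1 = Linear(784, 4096)
                 self.fc2 = Linear(4096, 4096)
                 self.head = nn.Linear(4096, 10)  # classifier stays in high precision

             def forward(self, x):
                 x = torch.relu(self.fc1(x.flatten(1)))
                 x = torch.relu(self.fc2(x))
                 return self.head(x)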

      Training Loop
      ^^^^^^^^^^^^^

      The training function wraps the forward pass in ``te.autocast`` to enable FP8:

      .. literalinclude:: ../../examples/pytorch/mnist/main.py
         :language: python
         :lines: 49-68
         :linenos:
         :emphasize-lines: 6-7

      Key points (sketched in the snippet below):

      - The forward pass runs inside ``te.autocast(enabled=use_fp8)``.
      - The backward pass happens outside the context manager, but it still consumes the FP8 state captured during the forward pass.
      - A standard PyTorch optimizer works without modification.
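
      A condensed sketch of that structure, assuming the ``te.autocast`` API used throughout
      this guide; function and variable names are illustrative:

      .. code-block:: python

         import torch.nn.functional as F
         import transformer_engine.pytorch as te

         def train_step(model, data, target, optimizer, use_fp8: bool):
             # Illustrative sketch of the loop structure, not main.py verbatim.
             optimizer.zero_grad()
             with te.autocast(enabled=use_fp8):
                 output = model(data)           # forward pass runs FP8 GEMMs
             loss = F.cross_entropy(output, target)
             loss.backward()                    # outside autocast; reuses FP8 state
             optimizer.step()
             return loss.item()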

      Inference
      ^^^^^^^^^

      During inference, we use the same autocast pattern:

      .. literalinclude:: ../../examples/pytorch/mnist/main.py
         :language: python
         :lines: 83-95
         :linenos:
         :emphasize-lines: 5-6
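
      In sketch form (again with illustrative names), the only additions over the training
      step are ``model.eval()`` and disabling gradient tracking:

      .. code-block:: python

         import torch
         import transformer_engine.pytorch as te

         @torch.no_grad()
         def evaluate(model, loader, use_fp8: bool):
             # Illustrative sketch: same autocast pattern, no gradients needed.
             model.eval()
             correct = 0
             for data, target in loader:
                 with te.autocast(enabled=use_fp8):
                     output = model(data)
                 correct += (output.argmax(dim=1) == target).sum().item()
             return correct / len(loader.dataset)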

   .. tab:: FSDP Integration

      This example demonstrates using Transformer Engine with PyTorch's Fully Sharded Data Parallel (FSDP).

      Complete Example
      ^^^^^^^^^^^^^^^^

      .. literalinclude:: ../../examples/pytorch/fsdp/fsdp.py
         :language: python
         :lines: 1-50
         :linenos:

      FSDP wraps the Transformer Engine model and handles distributed training automatically.
      The FP8 precision works seamlessly with FSDP's sharding strategy.
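
      The core wiring reduces to a few lines. A minimal sketch, assuming a standard NCCL
      process group and illustrative layer sizes; ``fsdp.py`` adds wrapping policies and
      further configuration on top of this:

      .. code-block:: python

         import torch
         import torch.distributed as dist
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
         import transformer_engine.pytorch as te

         dist.init_process_group(backend="nccl")
         torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

         layer = te.TransformerLayer(
             hidden_size=1024,                # illustrative sizes
             ffn_hidden_size=4096,
             num_attention_heads=16,
         ).cuda()
         model = FSDP(layer)                  # FSDP shards the TE parameters

         with te.autocast(enabled=True):      # same autocast pattern as above
             x = torch.randn(128, 8, 1024, device="cuda")  # [seq, batch, hidden]
             out = model(x)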

   .. tab:: Communication Overlap

      An advanced example showing how to overlap communication with computation for better
      performance in distributed training.

      Overlap Pattern
      ^^^^^^^^^^^^^^^

      .. literalinclude:: ../../examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py
         :language: python
         :lines: 1-40
         :linenos:

      This pattern lets tensor-parallel communication run concurrently with GEMM computation,
      reducing the overall step time in multi-GPU setups.
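
      Transformer Engine configures its GEMM/communication overlap internally, but the
      underlying idea can be sketched with plain PyTorch async collectives. The snippet
      below is a generic illustration of the principle, not TE's actual mechanism, and it
      assumes ``dist.init_process_group`` has already been called:

      .. code-block:: python

         import torch
         import torch.distributed as dist

         def overlapped_step(grad_chunk, a, b):
             # Kick off the collective without blocking the host...
             handle = dist.all_reduce(grad_chunk, async_op=True)
             # ...run independent matrix math while the network is busy...
             c = a @ b
             # ...and synchronize only when the reduced tensor is needed.
             handle.wait()
             return c, grad_chunk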