fixes
Signed-off-by: Pawel Gadzinski <[email protected]>
pggPL committed Oct 23, 2025
commit 4b4802666704fcf93ed11254363885d86a632b75
3 changes: 2 additions & 1 deletion docs/debug/1_getting_started.rst
@@ -15,7 +15,8 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
- log the statistics for each of the tensors in every matrix multiply (GEMM) operation,
- run selected GEMMs in higher precision,
- run current scaling - with one scaling factor per tensor - for particular GEMMs,
-- test new precisions and integrate them with FP8 training,
+- test new precisions and integrate them with quantized training (FP8, NVFP4, etc.),
+- monitor quantization errors and underflows for different precision formats,
- ... and many more.

There are 4 things one needs to do to use Transformer Engine debug features:
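For orientation, turning these features on follows the pattern from the getting-started guide: describe the features in a YAML config, then initialize the debug API before constructing the model. Below is a minimal sketch assuming the nvdlfw_inspect-based debug_api entry point used throughout these docs; the config section name, stats list, and feature_dirs path are illustrative assumptions, not confirmed values.

    # Minimal sketch, assuming the nvdlfw_inspect-based debug_api entry point
    # from the Transformer Engine debug docs; keys and paths are illustrative.
    import pathlib
    import nvdlfw_inspect.api as debug_api

    config = """
    stats_logging:
      enabled: True
      layers:
        layer_types: [fc1]
      transformer_engine:
        LogTensorStats:
          enabled: True
          stats: [min, max, mean, std, l1_norm]
          tensors: [activation]
          freq: 1
    """
    pathlib.Path("debug_config.yaml").write_text(config)

    debug_api.initialize(
        config_file="debug_config.yaml",
        feature_dirs=["transformer_engine/debug/features"],
    )
    # ...build the model and train; call debug_api.step() after each
    # optimizer step so per-rank statistics are gathered and logged.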
7 changes: 5 additions & 2 deletions docs/debug/3_api_features.rst
@@ -8,7 +8,10 @@ Debug features

.. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats
.. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
-.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
-.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
+.. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats
+.. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM
+.. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer
.. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling
.. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant
+.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
+.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
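Since this hunk generalizes the FP8-specific features into quantization-wide ones while keeping the old classes documented, a config for the new feature would presumably mirror the existing FP8 examples. A hypothetical section follows, embedded as a Python string to match the other sketches; the keys accepted by DisableQuantizationGEMM are an assumption based on the old DisableFP8GEMM feature.

    # Hypothetical YAML for the renamed feature; the section layout mirrors
    # existing DisableFP8GEMM examples and the accepted keys are assumed.
    disable_gemms_config = """
    disable_quantized_gemms:
      enabled: True
      layers:
        layer_types: [self_attention]
      transformer_engine:
        DisableQuantizationGEMM:
          enabled: True
          gemms: [dgrad, wgrad]  # run these GEMMs in high precision
    """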
5 changes: 4 additions & 1 deletion transformer_engine/debug/features/log_fp8_tensor_stats.py
@@ -23,6 +23,8 @@
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

+import transformer_engine_torch as tex
+
try:
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

@@ -210,6 +212,7 @@ def check_if_stat_is_supported(self, stat: str, current_recipe: str):

def get_recipe_from_stat(self, stat: str, default_recipe: str = ""):
"""Returns the recipe name from the stat string."""
+
columnwise_stat = stat.endswith("_columnwise")
for recipe_name in ALL_RECIPE_NAMES:
if recipe_name in stat:
@@ -234,7 +237,7 @@ def update_aux_dict(
Yields the aux_dict.
Needs cleanup after use, because it may change the usage of the quantized tensor.
"""
-fp8_dtype = None
+fp8_dtype = tex.DType.kFloat8E4M3
if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]:
assert isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer)
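The substantive fix in this file replaces the None initialization of fp8_dtype with a concrete E4M3 default, so code after the recipe branches never receives None. A compressed sketch of the resulting pattern; the recipe names come from the visible hunk, while the branch body is a hypothetical stand-in for what the real code derives from the quantizer.

    # Sketch of the defaulting pattern from this hunk; the quantizer
    # attribute access is hypothetical, not the library's exact code.
    import transformer_engine_torch as tex

    def resolve_fp8_dtype(recipe_name, quantizer):
        fp8_dtype = tex.DType.kFloat8E4M3  # concrete default instead of None
        if recipe_name in ("fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"):
            fp8_dtype = getattr(quantizer, "dtype", fp8_dtype)  # hypothetical
        return fp8_dtype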
11 changes: 3 additions & 8 deletions transformer_engine/debug/features/log_nvfp4_tensor_stats.py
@@ -17,16 +17,11 @@
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter

+from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage

@Registry.register_feature(namespace="transformer_engine")
class LogNvfp4TensorStats(BaseLogTensorStats):
"""
Logs statistics of NVFP4 quantized tensors.

This feature is specifically designed for NVFP4 quantization and provides:
- underflows%: percentage of non-zero elements clipped to 0 after quantization (computed from packed FP4 data)
- mse: mean squared error between original and quantized-dequantized tensor
"""Logs statistics of NVFP4 quantized tensors.

In distributed runs each rank first computes its local statistics; the values
are gathered the next time `debug_api.step()` is called. Remember to call
@@ -170,7 +165,7 @@ def inspect_tensor(
)

assert isinstance(
-quantized_tensor, QuantizedTensor
+quantized_tensor, NVFP4TensorStorage
), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a QuantizedTensor."

for stat in config["stats"]:
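The docstring removed here described the two NVFP4-specific statistics: underflows% (the share of non-zero elements clipped to zero by quantization, computed from the packed FP4 data) and mse (the error of the quantize-dequantize round trip). Conceptually they reduce to the following dequantize-based sketch, which is illustrative only and not the feature's actual packed-data implementation.

    import torch

    # Conceptual definitions of the stats named in the removed docstring; the
    # real feature derives underflows% from packed FP4 data, not a round trip.
    def underflows_pct(original: torch.Tensor, dequantized: torch.Tensor) -> float:
        nonzero = original != 0
        clipped = nonzero & (dequantized == 0)
        return 100.0 * clipped.sum().item() / max(nonzero.sum().item(), 1)

    def mse(original: torch.Tensor, dequantized: torch.Tensor) -> float:
        return torch.mean((original - dequantized) ** 2).item()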
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -70,7 +70,7 @@ def update_quantized(

def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
-return tex.quantize(tensor, self)
+return tex.quantize(tensor, self, None)

def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
Expand Down