Skip to content
Prev Previous commit
Next Next commit
update torchao api name
  • Loading branch information
danielvegamyhre committed Oct 1, 2025
commit 4d51ae25229ac9e79ab759bc40517e6de09ee343
4 changes: 2 additions & 2 deletions torchtitan/distributed/expert_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ def _get_a2a_func(self, a2a_impl: str):
elif a2a_impl == "mxfp8":
logger.info("Using mxfp8 all-to-all implementation")
from torchao.prototype.moe_training.kernels.mxfp8.comms import (
mxfp8_sync_all_to_all_v,
to_mxfp8_a2a_dequant,
)

return mxfp8_sync_all_to_all_v
return to_mxfp8_a2a_dequant
else:
raise ValueError(f"Unknown a2a_impl: {a2a_impl}")

Expand Down