Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
some linting
  • Loading branch information
lessw2020 committed Jun 22, 2025
commit 50ebdd2690b8aaef953a217b550960a03eb105c5
2 changes: 1 addition & 1 deletion torchtitan/experiments/deepseek_v3/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from model_config import deepseek_config_registry
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from transformers import AutoTokenizer

from torchtitan.tools.utils import Color
from transformers import AutoTokenizer

# Uncomment the model you want to run.
model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4)
Expand Down
16 changes: 12 additions & 4 deletions torchtitan/experiments/kernels/blackwell/cute_grouped_gemm_fwd.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Strategy using CUTLASS GroupedGemmKernel for group GEMM operations on Blackwell.

Expand Down Expand Up @@ -467,10 +473,12 @@ def _execute_cutlass_kernel(
num_groups = len(problem_sizes)

# Convert to CUTE tensors using improved converter
problem_sizes_cute, strides_cute, ptrs_cute = (
self.converter.create_metadata_tensors(
problem_sizes, strides_abc, ptrs_abc, device
)
(
problem_sizes_cute,
strides_cute,
ptrs_cute,
) = self.converter.create_metadata_tensors(
problem_sizes, strides_abc, ptrs_abc, device
)

# Get other required components
Expand Down
6 changes: 6 additions & 0 deletions torchtitan/experiments/kernels/blackwell/group_gemm_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Dict, List, Optional, Tuple

Expand Down
Loading