Skip to content

Commit 84040e3

Browse files
committed
define cluster_ctarank
1 parent 1c94bab commit 84040e3

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

include/common/util.cuh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,21 @@ struct shared_allocator {
285285
*/
286286
using tma_allocator = shared_allocator<1024>;
287287
using tma_swizzle_allocator = tma_allocator; // swizzled TMA modes require up to 1024 byte alignments :/
288+
289+
/* Cluster helpers: the cluster's index within the grid (%clusterid) and this CTA's rank within its cluster (%cluster_ctarank) */
290+
/**
 * @brief Read the three components of the PTX %clusterid special register.
 *
 * %clusterid identifies the cluster within the grid; it is not the position
 * of the calling CTA inside its cluster.
 * NOTE(review): per the PTX ISA, %clusterid needs sm_90 or newer -- confirm the
 * enclosing preprocessor guard restricts this to such targets.
 *
 * @return int3 whose x, y and z hold %clusterid.x, %clusterid.y and %clusterid.z.
 */
__device__ static inline int3 clusterIdx() {
    int cid_x, cid_y, cid_z;
    // One register read per dimension; the operands are plain 32-bit registers.
    asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(cid_x));
    asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(cid_y));
    asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(cid_z));

    int3 result;
    result.x = cid_x;
    result.y = cid_y;
    result.z = cid_z;
    return result;
}
297+
/**
 * @brief Read the PTX %cluster_ctarank special register.
 *
 * NOTE(review): per the PTX ISA this is the calling CTA's rank within its
 * cluster (flattened across all dimensions) and needs sm_90 or newer -- confirm
 * the enclosing preprocessor guard restricts this to such targets.
 *
 * @return The 32-bit register value, converted to int.
 */
__device__ static inline int cluster_ctarank() {
    uint32_t rank_in_cluster;
    asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank_in_cluster));
    // Unsigned register value narrowed to the signed return type, as before.
    return static_cast<int>(rank_in_cluster);
}
302+
288303
#endif
289304

290305
} // namespace kittens

0 commit comments

Comments
 (0)