Skip to content
This repository was archived by the owner on Feb 18, 2026. It is now read-only.
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Remove some registers
  • Loading branch information
woct0rdho committed Aug 31, 2025
commit 7a6a98feec6b6b1cd1425e536da3e58c6bee9e40
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ static const Fp8ConversionDesc Fp16_to_Fp8E5M2_RTNE(bool hasNativeFP) {
ret = {"{ \n"
".reg .b32 a<2>, b<2>; \n"
"add.u32 a0, $1, 0x007f007f; \n" // Round to nearest even:
"add.u32 a1, $2, 0x007f007f; \n" // If LSB of kept mantissa is 1
"and.b32 b0, $1, 0x01000100; \n" // then add 0x80 to mantissa
"and.b32 b1, $2, 0x01000100; \n" // else add 0x7f to mantissa
"add.u32 a1, $2, 0x007f007f; \n" // If LSB of fp8 mantissa is 1
"and.b32 b0, $1, 0x01000100; \n" // then add 0x80 to fp16 mantissa
"and.b32 b1, $2, 0x01000100; \n" // else add 0x7f to fp16 mantissa
"shr.b32 b0, b0, 8; \n"
"shr.b32 b1, b1, 8; \n"
"add.u32 a0, a0, b0; \n"
Expand Down Expand Up @@ -71,24 +71,23 @@ static const Fp8ConversionDesc Fp8E5M2_to_Bf16(bool hasNativeFP) {
// TODO: Handle inf and NaN
ret = {
"{ \n"
".reg .b32 a<2>, b<2>, c<4>, d<4>, e112; \n" // if input = 0xf1f2f3f4
"mov.u32 e112, 0x77800000; \n"
".reg .b32 a<2>, b<2>, c<4>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift to bf16 position
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift to bf16 position
"and.b32 c0, b0, 0xffff0000; \n" // c0 = f3
"shl.b32 c1, b0, 16; \n" // c1 = f4
"and.b32 c2, b1, 0xffff0000; \n" // c2 = f1
"shl.b32 c3, b1, 16; \n" // c3 = f2
"mul.f32 d0, c0, e112; \n" // move exponent bias
"mul.f32 d1, c1, e112; \n" // from 15 to 127
"mul.f32 d2, c2, e112; \n"
"mul.f32 d3, c3, e112; \n"
"prmt.b32 b0, d0, d1, 0x3276; \n" // b0 = 0xd3d4
"prmt.b32 b1, d2, d3, 0x3276; \n" // b1 = 0xd1d2
"mul.f32 c0, c0, 0x77800000; \n" // move exponent bias
"mul.f32 c1, c1, 0x77800000; \n" // from 15 to 127
"mul.f32 c2, c2, 0x77800000; \n"
"mul.f32 c3, c3, 0x77800000; \n"
"prmt.b32 b0, c0, c1, 0x3276; \n" // b0 = 0xc0c1
"prmt.b32 b1, c2, c3, 0x3276; \n" // b1 = 0xc2c3
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0=b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}",
Expand All @@ -97,18 +96,16 @@ static const Fp8ConversionDesc Fp8E5M2_to_Bf16(bool hasNativeFP) {
ret = {
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
".reg .b32 e112; \n" // 2**112 represented as
"mov.u32 e112, 0x77807780; \n" // bf16x2
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
"lop3.b32 b0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 b1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"mul.rn.bf16x2 $0, b0, e112; \n" // b0.exp += 2**7-2**4
"mul.rn.bf16x2 $1, b1, e112; \n" // exponent compensate = 112
"mul.rn.bf16x2 $0, b0, 0x77807780; \n" // b0.exp += 2**7-2**4
"mul.rn.bf16x2 $1, b1, 0x77807780; \n" // exponent compensate = 112
"}",
32, 32, 4};
}
Expand Down Expand Up @@ -183,8 +180,7 @@ static const Fp8ConversionDesc Fp8E4M3Nv_to_Fp16(bool hasNativeFP) {
// TODO: Handle NaN
ret = {
"{ \n"
".reg .b32 a<2>, b<2>, c<4>, d<4>, e8; \n" // if input = 0xf1f2f3f4
"mov.u32 e8, 0x43800000; \n"
".reg .b32 a<2>, b<2>, c<4>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
Expand All @@ -195,12 +191,12 @@ static const Fp8ConversionDesc Fp8E4M3Nv_to_Fp16(bool hasNativeFP) {
"shl.b32 c1, b0, 16; \n" // c1 = f4
"and.b32 c2, b1, 0xffff0000; \n" // c2 = f1
"shl.b32 c3, b1, 16; \n" // c3 = f2
"mul.f32 d0, c0, e8; \n" // move exponent bias
"mul.f32 d1, c1, e8; \n" // from 7 to 15
"mul.f32 d2, c2, e8; \n"
"mul.f32 d3, c3, e8; \n"
"prmt.b32 b0, d0, d1, 0x3276; \n" // b0 = 0xd0d1
"prmt.b32 b1, d2, d3, 0x3276; \n" // b1 = 0xd2d3
"mul.f32 c0, c0, 0x43800000; \n" // move exponent bias
"mul.f32 c1, c1, 0x43800000; \n" // from 7 to 15
"mul.f32 c2, c2, 0x43800000; \n"
"mul.f32 c3, c3, 0x43800000; \n"
"prmt.b32 b0, c0, c1, 0x3276; \n" // b0 = 0xc0c1
"prmt.b32 b1, c2, c3, 0x3276; \n" // b1 = 0xc2c3
"shl.b32 b0, b0, 3; \n" // b0 <<= 3
"shl.b32 b1, b1, 3; \n" // shift to fp16 position
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0=b0|(0x80008000&a0)
Expand Down Expand Up @@ -246,8 +242,7 @@ static const Fp8ConversionDesc Fp8E4M3Nv_to_Bf16(bool hasNativeFP8,
// TODO: Handle NaN
ret = {
"{ \n"
".reg .b32 a<2>, b<2>, c<4>, d<4>, e120; \n" // if input = 0xf1f2f3f4
"mov.u32 e120, 0x7b800000; \n"
".reg .b32 a<2>, b<2>, c<4>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
Expand All @@ -258,12 +253,12 @@ static const Fp8ConversionDesc Fp8E4M3Nv_to_Bf16(bool hasNativeFP8,
"shl.b32 c1, b0, 16; \n" // c1 = f4
"and.b32 c2, b1, 0xffff0000; \n" // c2 = f1
"shl.b32 c3, b1, 16; \n" // c3 = f2
"mul.f32 d0, c0, e120; \n" // move exponent bias
"mul.f32 d1, c1, e120; \n" // from 7 to 127
"mul.f32 d2, c2, e120; \n"
"mul.f32 d3, c3, e120; \n"
"prmt.b32 b0, d0, d1, 0x3276; \n" // b0 = 0xd0d1
"prmt.b32 b1, d2, d3, 0x3276; \n" // b1 = 0xd2d3
"mul.f32 c0, c0, 0x7b800000; \n" // move exponent bias
"mul.f32 c1, c1, 0x7b800000; \n" // from 7 to 127
"mul.f32 c2, c2, 0x7b800000; \n"
"mul.f32 c3, c3, 0x7b800000; \n"
"prmt.b32 b0, c0, c1, 0x3276; \n" // b0 = 0xc0c1
"prmt.b32 b1, c2, c3, 0x3276; \n" // b1 = 0xc2c3
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0=b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}",
Expand Down