[WIP] ggml-hexagon: Q4_0 mm opt #17907
base: master
Changes from 25 commits
@@ -307,6 +307,50 @@ static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y,
     return hvx_vec_rmpy_x8_n(x, y, 1024);
 }
 
+static inline HVX_Vector_x2 hvx_vec_load_and_mul_d_rx2(const uint8_t * restrict r0_x_d,
+                                                       const uint8_t * restrict r1_x_d,
+                                                       const uint8_t * restrict y_d,
+                                                       const HVX_Vector rd_mask) {
+    HVX_Vector vy_d = *(const HVX_Vector *) y_d;
+    HVX_Vector r0_d = *(const HVX_Vector *) r0_x_d;
+    HVX_Vector r1_d = *(const HVX_Vector *) r1_x_d;
+
+    vy_d = Q6_Vh_vshuff_Vh(vy_d);
+    HVX_Vector r01_d = Q6_V_vmux_QVV(rd_mask, r0_d, r1_d);
+
+    vy_d = Q6_Vh_vshuffe_VhVh(vy_d, vy_d);
+    r01_d = Q6_Vh_vshuff_Vh(r01_d);
+
+    HVX_VectorPair r01_dd = Q6_Wqf32_vmpy_VhfVhf(r01_d, vy_d);
+
+    HVX_Vector_x2 r;
+    r.v[0] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(r01_dd));
+    r.v[1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(r01_dd));
+    return r;
+}
+
+static inline HVX_Vector_x4 hvx_vec_load_and_mul_d_r2x2(const uint8_t * restrict r0_x_d,
+                                                        const uint8_t * restrict r1_x_d,
+                                                        const uint8_t * restrict y_d) {
+    HVX_Vector vy_d = *(const HVX_Vector *) y_d;
+    HVX_Vector r0_d = *(const HVX_Vector *) r0_x_d;
+    HVX_Vector r1_d = *(const HVX_Vector *) r1_x_d;
+
+    vy_d = Q6_Vh_vshuff_Vh(vy_d);
+    r0_d = Q6_Vh_vshuff_Vh(r0_d);
+    r1_d = Q6_Vh_vshuff_Vh(r1_d);
+
+    HVX_VectorPair r0_dd = Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d);
+    HVX_VectorPair r1_dd = Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d);
+
+    HVX_Vector_x4 r;
+    r.v[0] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(r0_dd));
+    r.v[1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(r0_dd));
+    r.v[2] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(r1_dd));
+    r.v[3] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(r1_dd));
+    return r;
+}
+
 static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % 32 == 0);  // min sub-block size
     assert((unsigned long) vx % 128 == 0);
@@ -393,11 +437,11 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
 
     const uint32_t qk = QK_Q4_0x4x2 * 4;
 
-    const uint32_t x_dblk_size = 8 * 4 * 2;  // 32x __fp16
+    const uint32_t x_dblk_size = 8 * 4 * sizeof(uint16_t);  // 32x __fp16
     const uint32_t x_qblk_size = qk / 2;  // int4
     const uint32_t x_qrow_size = n / 2;  // int4 (not padded)
 
-    const uint32_t y_dblk_size = 8 * 4 * 2;  // 32x __fp16
+    const uint32_t y_dblk_size = 8 * 4 * sizeof(uint16_t);  // 32x __fp16
     const uint32_t y_qblk_size = qk;  // int8
     const uint32_t y_qrow_size = n;  // int8 (not padded)
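For reference, the size arithmetic behind these constants: 8 * 4 = 32 __fp16 scales per super-block, i.e. 64 bytes, exactly half of a 128-byte HVX vector. A throwaway check (not part of the patch):

#include <stdint.h>

// 64 bytes is half of VLEN (128), which is what lets the rx2 paths below
// pack the scale blocks of two rows into a single vector.
_Static_assert(8 * 4 * sizeof(uint16_t) == 64, "32 fp16 scales == 64 bytes");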
@@ -422,26 +466,65 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
     const uint32_t nloe = n % qk;  // num leftover elemements
 
     uint32_t i = 0;
-    for (; i < nb; i++) {
+    for (; i + 1 < nb; i += 2) {
+        HVX_Vector r00_ia;
+        HVX_Vector r10_ia;
+        {
+            HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+            HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+            HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+            r00_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+            r10_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+        }
+
+        HVX_Vector r01_ia;
+        HVX_Vector r11_ia;
+        {
+            HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + (i + 1) * y_qblk_size);
+            HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + (i + 1) * x_qblk_size);
+            HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + (i + 1) * x_qblk_size);
+
+            r01_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+            r11_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+        }
+
+        HVX_Vector_x4 r_dd =
+            hvx_vec_load_and_mul_d_r2x2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size, y_d + i * y_dblk_size);
Contributor (author): Optimized the scale multiplication step. The previous implementation only processed 32x __fp16 scales at a time, using only the low half of the qf32 multiply result; the new helpers consume the full vector.

Collaborator: I'm getting garbled output for all models.
+
+        HVX_Vector r00_fa = Q6_Vqf32_vmpy_VsfVsf(r00_ia, r_dd.v[0]);
+        HVX_Vector r01_fa = Q6_Vqf32_vmpy_VsfVsf(r01_ia, r_dd.v[1]);
+
+        HVX_Vector r10_fa = Q6_Vqf32_vmpy_VsfVsf(r10_ia, r_dd.v[2]);
+        HVX_Vector r11_fa = Q6_Vqf32_vmpy_VsfVsf(r11_ia, r_dd.v[3]);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r00_fa);
+        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r10_fa);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r01_fa);
+        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r11_fa);
+    }
+
+    const HVX_VectorPred rd_mask = Q6_Q_vsetq_R(VLEN / 2);
+    r1_x_d -= VLEN / 2;  // make sure r1 at the high half of the vector
+
+    if (i < nb) {
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
 
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
         HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
-        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
-        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+        HVX_Vector_x2 r_dd = hvx_vec_load_and_mul_d_rx2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size,
+                                                        y_d + i * y_dblk_size, rd_mask);
 
-        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
-        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
-
-        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
-        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r_dd.v[0]);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r_dd.v[1]);
 
         r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
         r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        i++;
     }
 
     // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
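For reference, a scalar sketch of what the scale step computes (hypothetical names; assumes a toolchain with __fp16 support such as hexagon-clang): per 32-element sub-block, the combined fp32 scale is the product of the weight-row scale and the activation scale. The helpers above produce these products a full 64x fp16 vector at a time instead of discarding the high half of the multiply.

// Scalar reference for the scale step (illustrative only, not PR code)
static void scale_mul_ref(const __fp16 * r0_d, const __fp16 * r1_d, const __fp16 * y_d,
                          float * r0_dd, float * r1_dd, int nblk) {
    for (int b = 0; b < nblk; b++) {
        r0_dd[b] = (float) r0_d[b] * (float) y_d[b];  // row 0 combined scale
        r1_dd[b] = (float) r1_d[b] * (float) y_d[b];  // row 1 combined scale
    }
}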
@@ -453,17 +536,13 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
         HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
-        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
-        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
-
-        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
-        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+        HVX_Vector_x2 r_dd = hvx_vec_load_and_mul_d_rx2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size,
+                                                        y_d + i * y_dblk_size, rd_mask);
 
         // Zero out unused scales
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
-        r0_dd = Q6_V_vand_QV(bmask, r0_dd);
-        r1_dd = Q6_V_vand_QV(bmask, r1_dd);
+        HVX_Vector r0_dd = Q6_V_vand_QV(bmask, r_dd.v[0]);
+        HVX_Vector r1_dd = Q6_V_vand_QV(bmask, r_dd.v[1]);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
@@ -473,8 +552,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
     }
 
     // Convert into fp32 and reduce
-    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
-    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
+    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
     HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
 
     hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
@@ -594,6 +673,9 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
     const uint32_t nb = n / qk;  // num full blocks
     int32_t nloe = n % qk;  // num leftover elemements (must be signed)
 
+    const HVX_VectorPred rd_mask = Q6_Q_vsetq_R(VLEN / 2);
+    r1_x_d -= VLEN / 2;  // make sure r1 at the high half of the vector
+
     uint32_t i = 0;
     for (; i < nb; i++) {
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
@@ -603,15 +685,11 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
         HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
-        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
-        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+        HVX_Vector_x2 r_dd = hvx_vec_load_and_mul_d_rx2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size,
+                                                        y_d + i * y_dblk_size, rd_mask);
 
-        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
-        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
-
-        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
-        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r_dd.v[0]);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r_dd.v[1]);
 
         r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
         r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
@@ -626,17 +704,13 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
         HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
-        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
-        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
-
-        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
-        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+        HVX_Vector_x2 r_dd = hvx_vec_load_and_mul_d_rx2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size,
+                                                        y_d + i * y_dblk_size, rd_mask);
 
         // Zero out unused scales
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
-        r0_dd = Q6_V_vand_QV(bmask, r0_dd);
-        r1_dd = Q6_V_vand_QV(bmask, r1_dd);
+        HVX_Vector r0_dd = Q6_V_vand_QV(bmask, r_dd.v[0]);
+        HVX_Vector r1_dd = Q6_V_vand_QV(bmask, r_dd.v[1]);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
@@ -646,8 +720,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
     }
 
     // Convert into fp32 and reduce
-    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
-    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
+    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
     HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
 
     hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
@@ -684,8 +758,12 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
     // Compute combined scale (fp32).
     // Apply scale to acc and accumulate into the row sum (qf32).
 
-    const uint32_t nb = n / qk;  // num full blocks
-    int32_t nloe = n % qk;  // num leftover elemements (must be signed)
+    const uint32_t nb = n / qk;  // num full blocks
+    int32_t nloe = n % qk;  // num leftover elemements (must be signed)
+
+    const HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+    const HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+    const HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
 
     uint32_t i = 0;
     for (; i < nb; i++) {
@@ -698,19 +776,16 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
 
         // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
-        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
-        vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
-        vy_d = Q6_Vsf_equals_Vqf32(vy_d);
+        vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+        vy_d = Q6_Vsf_equals_Vqf32(vy_d);
 
         // Convert rX_d scales from e8m0 to fp32
         // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
         // Left shift with zero fill to create FP32
         // FIXME: might need to handle zero as a special case (see ggml-cpu code)
-        HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
-        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
-        r0_d = Q6_V_vdelta_VV(r0_d, expand);
-        r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
-        r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+        r0_d = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
 
         HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
@@ -738,11 +813,9 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
         // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
         // Left shift with zero fill to create FP32
         // FIXME: might need to handle zero as a special case (see ggml-cpu code)
-        HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
-        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
-        r0_d = Q6_V_vdelta_VV(r0_d, expand);
-        r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
-        r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+        r0_d = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
 
         HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
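For reference, a scalar sketch of the e8m0 conversion trick (hypothetical helper, not PR code): an e8m0 scale is a bare fp32 exponent, so placing the byte at bits 30..23 of a zeroed word yields 2^(e - 127) directly; that is what the vdelta expand, mask, and shift-left-by-23 sequence does per 32-bit lane.

#include <stdint.h>

static inline float e8m0_to_fp32_ref(uint8_t e) {
    union { uint32_t u; float f; } v;
    v.u = (uint32_t) e << 23;  // zero sign, zero mantissa, exponent = e
    return v.f;                // e == 0 yields 0.0f here, hence the FIXME above
}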
@@ -888,8 +961,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
     }
 
     // Convert into fp32 and reduce
-    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
-    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
+    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
     HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
 
     hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
@@ -917,14 +990,15 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
 
     // for some reason we need volatile here so that the compiler doesn't try anything funky
     volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
+    const HVX_Vector kOne = Q6_Vh_vsplat_R(0x3C00);  // 1.0 in fp16
 
     uint32_t i = 0;
 
     for (i = 0; i < nv0; i++) {
         HVX_VectorPair yp = vy[i];
 
         HVX_Vector x = vx[i];
-        HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00));  // mul by 1.0
+        HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), kOne);  // mul by 1.0
 
         HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
         HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
@@ -937,7 +1011,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
         HVX_VectorPair yp = vy[i];
 
         HVX_Vector x = vx[i];
-        HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00));  // mul by 1.0
+        HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), kOne);  // mul by 1.0
 
         if (nv1 >= 32) {
             HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
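The multiply-by-1.0 is the usual HVX widening idiom: Q6_Wqf32_vmpy_VhfVhf against a splatted fp16 1.0 turns 64 half-precision lanes into a pair of qf32 vectors, so it doubles as a fp16-to-fp32 conversion. Hoisting kOne simply keeps the splat out of both loop bodies. A minimal sketch of the idiom (assuming the same HVX intrinsics as above):

static inline HVX_VectorPair hvx_fp16_to_qf32_ref(HVX_Vector x, HVX_Vector one_hf) {
    // numerically x * 1.0, but the product comes out widened to qf32
    return Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), one_hf);
}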
@@ -2199,7 +2273,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
 
             assert(i02 >= 0 && i02 < n_as);
 
-            MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
+            MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping){ id, iid1 };
             matrix_row_counts[i02] += 1;
         }
     }
QQ: Given the update to 64x f16, is it safe to assume that x_d and y_d are now aligned to HVX_Vector?
I don't think so. The current REPACK format is all quants (4-bit nibbles) followed by the scales. For models where the number of elements per row is a multiple of 128 the scales will be aligned, but for something like gpt-oss-20b with 2880-element rows they will not be.
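To make the arithmetic concrete, here is a hypothetical helper (not from the PR) that checks whether the scales of a repacked row start at a vector-aligned offset, assuming the layout just described: n/2 bytes of quant nibbles first, scales immediately after, no padding in between, and an aligned row base.

#include <stdbool.h>
#include <stdint.h>

#define VLEN 128  // HVX vector length in bytes

// Hypothetical check: the scales start vector-aligned only if the nibble
// area occupies a whole number of HVX vectors.
static bool scales_are_aligned(uint32_t n_per_row) {
    uint32_t quant_bytes = n_per_row / 2;  // Q4_0: two elements per byte
    return quant_bytes % VLEN == 0;
}
// gpt-oss-20b: n_per_row = 2880 -> 1440 % 128 == 32 -> unaligned, as noted above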
Reverting to unaligned scale loading.
Though, since the scales currently follow the quantization layout, it may be worth implementing an aligned load path for specific row element counts for better performance.
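One possible shape for that idea, as a sketch only (hypothetical helper, not the PR's code; assumes the HVX types from hexagon_types.h): dispatch on the scale pointer's alignment, which is fixed per row size.

#include <stdint.h>

static inline HVX_Vector hvx_load_scales_ref(const uint8_t * p) {
    if (((uintptr_t) p % 128) == 0) {
        return *(const HVX_Vector *) p;   // aligned vector load
    }
    return *(const HVX_UVector *) p;      // unaligned vector load
}

In practice the check would presumably be done once per row size when selecting the kernel variant, rather than per load as shown here.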