@@ -2658,35 +2658,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
26582658 const int8x16_t v0_1ls = vsubq_s8 (v0_1l , s8b );
26592659 const int8x16_t v0_1hs = vsubq_s8 (v0_1h , s8b );
26602660
2661+ // interleave
2662+ const int8x16_t v0_0lz = vzip1q_s8 (v0_0ls , v0_0hs );
2663+ const int8x16_t v0_0hz = vzip2q_s8 (v0_0ls , v0_0hs );
2664+ const int8x16_t v0_1lz = vzip1q_s8 (v0_1ls , v0_1hs );
2665+ const int8x16_t v0_1hz = vzip2q_s8 (v0_1ls , v0_1hs );
2666+
26612667 // load y
26622668 const int8x16_t v1_0l = vld1q_s8 (y0 -> qs );
26632669 const int8x16_t v1_0h = vld1q_s8 (y0 -> qs + 16 );
26642670 const int8x16_t v1_1l = vld1q_s8 (y1 -> qs );
26652671 const int8x16_t v1_1h = vld1q_s8 (y1 -> qs + 16 );
26662672
2667- // interleave
2668- const int8x16_t v1_0ls = vuzp1q_s8 (v1_0l , v1_0h );
2669- const int8x16_t v1_0hs = vuzp2q_s8 (v1_0l , v1_0h );
2670- const int8x16_t v1_1ls = vuzp1q_s8 (v1_1l , v1_1h );
2671- const int8x16_t v1_1hs = vuzp2q_s8 (v1_1l , v1_1h );
2672-
26732673#if defined(__ARM_FEATURE_DOTPROD )
26742674 // dot product into int32x4_t
2675- const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0ls , v1_0ls ), v0_0hs , v1_0hs );
2676- const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1ls , v1_1ls ), v0_1hs , v1_1hs );
2675+ const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l ), v0_0hz , v1_0h );
2676+ const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l ), v0_1hz , v1_1h );
26772677
26782678 sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (p_0 ), x0 -> d * y0 -> d );
26792679 sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (p_1 ), x1 -> d * y1 -> d );
26802680#else
2681- const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0ls ), vget_low_s8 (v1_0ls ));
2682- const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0ls ), vget_high_s8 (v1_0ls ));
2683- const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0hs ));
2684- const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0hs ));
2685-
2686- const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1ls ));
2687- const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1ls ));
2688- const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1hs ));
2689- const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1hs ));
2681+ const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0lz ), vget_low_s8 (v1_0l ));
2682+ const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0lz ), vget_high_s8 (v1_0l ));
2683+ const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hz ), vget_low_s8 (v1_0h ));
2684+ const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hz ), vget_high_s8 (v1_0h ));
2685+
2686+ const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1lz ), vget_low_s8 (v1_1l ));
2687+ const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1lz ), vget_high_s8 (v1_1l ));
2688+ const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hz ), vget_low_s8 (v1_1h ));
2689+ const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hz ), vget_high_s8 (v1_1h ));
26902690
26912691 const int32x4_t pl0 = vaddq_s32 (vpaddlq_s16 (pl0l ), vpaddlq_s16 (pl0h ));
26922692 const int32x4_t ph0 = vaddq_s32 (vpaddlq_s16 (ph0l ), vpaddlq_s16 (ph0h ));
0 commit comments