@@ -9018,18 +9018,20 @@ static void ggml_compute_forward_rms_norm_f32(
90189018 GGML_ASSERT (ggml_are_same_shape (src0 , dst ));
90199019
90209020 if (params -> type == GGML_TASK_INIT || params -> type == GGML_TASK_FINALIZE ) {
9021+ atomic_store (params -> aic , 0 );
9022+
90219023 return ;
90229024 }
90239025
90249026 GGML_ASSERT (src0 -> nb [0 ] == sizeof (float ));
90259027
9026- const int ith = params -> ith ;
9028+ const int ith = params -> ith ; UNUSED ( ith );
90279029 const int nth = params -> nth ;
90289030
90299031 const int64_t ne00 = src0 -> ne [0 ];
90309032 const int64_t ne01 = src0 -> ne [1 ];
90319033 const int64_t ne02 = src0 -> ne [2 ];
9032- const int64_t ne03 = src0 -> ne [3 ];
9034+ const int64_t ne03 = src0 -> ne [3 ]; UNUSED ( ne03 );
90339035
90349036 const size_t nb01 = src0 -> nb [1 ];
90359037 const size_t nb02 = src0 -> nb [2 ];
@@ -9041,30 +9043,45 @@ static void ggml_compute_forward_rms_norm_f32(
90419043
90429044 const float eps = 1e-6f ; // TODO: make this a parameter
90439045
9044- // TODO: optimize
9045- for (int64_t i03 = 0 ; i03 < ne03 ; i03 ++ ) {
9046- for (int64_t i02 = 0 ; i02 < ne02 ; i02 ++ ) {
9047- for (int64_t i01 = ith ; i01 < ne01 ; i01 += nth ) {
9048- const float * x = (float * ) ((char * ) src0 -> data + i01 * nb01 + i02 * nb02 + i03 * nb03 );
9049-
9050- ggml_float sum = 0.0 ;
9051- for (int64_t i00 = 0 ; i00 < ne00 ; i00 ++ ) {
9052- sum += (ggml_float )(x [i00 ] * x [i00 ]);
9053- }
9046+ const int nr = ggml_nrows (src0 );
9047+ const int dr = (nr + 8 * nth - 1 )/(8 * nth );
90549048
9055- float mean = sum /ne00 ;
9049+ while (true) {
9050+ const int ir0 = atomic_fetch_add (params -> aic , dr );
90569051
9057- float * y = (float * ) ((char * ) dst -> data + i01 * nb1 + i02 * nb2 + i03 * nb3 );
9052+ for (int ir = ir0 ; ir < ir0 + dr ; ++ ir ) {
9053+ if (ir >= nr ) {
9054+ break ;
9055+ }
90589056
9059- memcpy ( y , x , ne00 * sizeof ( float ));
9060- // for ( int i00 = 0; i00 < ne00; i00++) {
9061- // y[i00] = x[i00] ;
9062- // }
9057+ // src0 indices
9058+ const int i03 = ir /( ne02 * ne01 );
9059+ const int i02 = ( ir - i03 * ne02 * ne01 )/ ne01 ;
9060+ const int i01 = ( ir - i03 * ne02 * ne01 - i02 * ne01 );
90639061
9064- const float scale = 1.0f / sqrtf ( mean + eps );
9062+ const float * x = ( float * ) (( char * ) src0 -> data + i01 * nb01 + i02 * nb02 + i03 * nb03 );
90659063
9066- ggml_vec_scale_f32 (ne00 , y , scale );
9064+ ggml_float sum = 0.0 ;
9065+ for (int64_t i00 = 0 ; i00 < ne00 ; i00 ++ ) {
9066+ sum += (ggml_float )(x [i00 ] * x [i00 ]);
90679067 }
9068+
9069+ float mean = sum /ne00 ;
9070+
9071+ float * y = (float * ) ((char * ) dst -> data + i01 * nb1 + i02 * nb2 + i03 * nb3 );
9072+
9073+ memcpy (y , x , ne00 * sizeof (float ));
9074+ // for (int i00 = 0; i00 < ne00; i00++) {
9075+ // y[i00] = x[i00];
9076+ // }
9077+
9078+ const float scale = 1.0f /sqrtf (mean + eps );
9079+
9080+ ggml_vec_scale_f32 (ne00 , y , scale );
9081+ }
9082+
9083+ if (ir0 + dr >= nr ) {
9084+ break ;
90689085 }
90699086 }
90709087}
@@ -9739,11 +9756,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
97399756 const int nb2 = dst -> nb [2 ];
97409757 const int nb3 = dst -> nb [3 ];
97419758
9742- const int ith = params -> ith ;
9759+ const int ith = params -> ith ; UNUSED ( ith );
97439760 const int nth = params -> nth ;
97449761
9745- UNUSED (ith );
9746-
97479762 GGML_ASSERT (ne02 == ne12 );
97489763 GGML_ASSERT (ne03 == ne13 );
97499764 GGML_ASSERT (ne2 == ne12 );
0 commit comments