1515
1616static int _ccv_nnc_rmsnorm_forw (const ccv_nnc_cmd_t cmd , const ccv_nnc_hint_t hint , const int flags , ccv_nnc_tensor_t * const * const inputs , const int input_size , ccv_nnc_tensor_t * const * const outputs , const int output_size , ccv_nnc_stream_context_t * const stream_context )
1717{
18- assert (input_size == 2 );
18+ assert (input_size == 2 || input_size == 1 );
1919 ccv_nnc_tensor_view_t * const a = (ccv_nnc_tensor_view_t * )inputs [0 ];
20- ccv_nnc_tensor_view_t * const scale = (ccv_nnc_tensor_view_t * )inputs [1 ];
20+ ccv_nnc_tensor_view_t * const scale = input_size >= 2 ? (ccv_nnc_tensor_view_t * )inputs [1 ] : 0 ;
2121 ccv_nnc_tensor_view_t * const b = (ccv_nnc_tensor_view_t * )outputs [0 ];
2222 ccv_nnc_tensor_view_t * const saved_inv_std = (ccv_nnc_tensor_view_t * )outputs [1 ];
2323 assert (ccv_nnc_tensor_nd (a -> info .dim ) <= CCV_NNC_MAX_DIM + 2 );
@@ -33,7 +33,8 @@ static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
3333 int bstride [CCV_NNC_MAX_DIM_ALLOC ];
3434 int scale_stride [CCV_NNC_MAX_DIM_ALLOC ];
3535 ccv_nnc_tensor_view_get_stride (a , astride );
36- ccv_nnc_tensor_view_get_stride (scale , scale_stride );
36+ if (scale )
37+ ccv_nnc_tensor_view_get_stride (scale , scale_stride );
3738 ccv_nnc_tensor_view_get_stride (b , bstride );
3839 // The epsilon is used a little bit differently from batch norm, it is outside of the sqrt in this case.
3940 const float epsilon = cmd .info .rmsnorm .epsilon ;
@@ -91,36 +92,66 @@ static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
9192 }
9293 }
9394 }
94- float * const scalep = scale -> data .f32 ;
95- int sdim [CCV_NNC_MAX_DIM_ALLOC ];
96- ccv_nnc_tensor_view_get_dim (scale , sdim );
97- // Do the straight-forward one, y = x * inv_std * scale + bias, we cannot allocate extra memory to help.
98- // There is no need for precompute since scale / bias is per element.
99- float * const bp = b -> data .f32 ;
100- for (i [0 ] = 0 ; i [0 ] < adim [0 ]; i [0 ]++ )
95+ if (cmd .info .rmsnorm .elementwise_affine )
10196 {
102- float * const ap0 = ap + i [0 ] * astride [0 ];
103- float * const bp0 = bp + i [0 ] * bstride [0 ];
104- float * const varp0 = rdim [0 ] == 1 ? varp : varp + i [0 ] * saved_inv_std_stride [0 ];
105- float * const scalep0 = sdim [0 ] == 1 ? scalep : scalep + i [0 ] * scale_stride [0 ];
106- for (i [1 ] = 0 ; i [1 ] < adim [1 ]; i [1 ]++ )
97+ float * const scalep = scale -> data .f32 ;
98+ int sdim [CCV_NNC_MAX_DIM_ALLOC ];
99+ ccv_nnc_tensor_view_get_dim (scale , sdim );
100+ // Do the straight-forward one, y = x * inv_std * scale + bias, we cannot allocate extra memory to help.
101+ // There is no need for precompute since scale / bias is per element.
102+ float * const bp = b -> data .f32 ;
103+ for (i [0 ] = 0 ; i [0 ] < adim [0 ]; i [0 ]++ )
107104 {
108- float * ap1 = ap0 + i [1 ] * astride [1 ];
109- float * bp1 = bp0 + i [1 ] * bstride [1 ];
110- float * const varp1 = rdim [1 ] == 1 ? varp0 : varp0 + i [1 ] * saved_inv_std_stride [1 ];
111- float * const scalep1 = sdim [1 ] == 1 ? scalep0 : scalep0 + i [1 ] * scale_stride [1 ];
112- for (i [2 ] = 0 ; i [2 ] < adim [2 ]; i [2 ]++ )
105+ float * const ap0 = ap + i [0 ] * astride [0 ];
106+ float * const bp0 = bp + i [0 ] * bstride [0 ];
107+ float * const varp0 = rdim [0 ] == 1 ? varp : varp + i [0 ] * saved_inv_std_stride [0 ];
108+ float * const scalep0 = sdim [0 ] == 1 ? scalep : scalep + i [0 ] * scale_stride [0 ];
109+ for (i [1 ] = 0 ; i [1 ] < adim [1 ]; i [1 ]++ )
113110 {
114- float * const varp2 = rdim [2 ] == 1 ? varp1 : varp1 + i [2 ] * saved_inv_std_stride [2 ];
115- float * const scalep2 = sdim [2 ] == 1 ? scalep1 : scalep1 + i [2 ] * scale_stride [2 ];
116- if (rdim [3 ] == 1 )
117- for (x = 0 ; x < adim [3 ]; x ++ )
118- bp1 [x ] = ap1 [x * astride [3 ]] * varp2 [0 ] * scalep2 [sdim [3 ] == 1 ? 0 : x ];
119- else
120- for (x = 0 ; x < adim [3 ]; x ++ )
121- bp1 [x ] = ap1 [x * astride [3 ]] * varp2 [x ] * scalep2 [sdim [3 ] == 1 ? 0 : x ];
122- ap1 += astride [2 ];
123- bp1 += bstride [2 ];
111+ float * ap1 = ap0 + i [1 ] * astride [1 ];
112+ float * bp1 = bp0 + i [1 ] * bstride [1 ];
113+ float * const varp1 = rdim [1 ] == 1 ? varp0 : varp0 + i [1 ] * saved_inv_std_stride [1 ];
114+ float * const scalep1 = sdim [1 ] == 1 ? scalep0 : scalep0 + i [1 ] * scale_stride [1 ];
115+ for (i [2 ] = 0 ; i [2 ] < adim [2 ]; i [2 ]++ )
116+ {
117+ float * const varp2 = rdim [2 ] == 1 ? varp1 : varp1 + i [2 ] * saved_inv_std_stride [2 ];
118+ float * const scalep2 = sdim [2 ] == 1 ? scalep1 : scalep1 + i [2 ] * scale_stride [2 ];
119+ if (rdim [3 ] == 1 )
120+ for (x = 0 ; x < adim [3 ]; x ++ )
121+ bp1 [x ] = ap1 [x * astride [3 ]] * varp2 [0 ] * scalep2 [sdim [3 ] == 1 ? 0 : x ];
122+ else
123+ for (x = 0 ; x < adim [3 ]; x ++ )
124+ bp1 [x ] = ap1 [x * astride [3 ]] * varp2 [x ] * scalep2 [sdim [3 ] == 1 ? 0 : x ];
125+ ap1 += astride [2 ];
126+ bp1 += bstride [2 ];
127+ }
128+ }
129+ }
130+ } else {
131+ // Do the straight-forward one, y = x * inv_std, we cannot allocate extra memory to help.
132+ float * const bp = b -> data .f32 ;
133+ for (i [0 ] = 0 ; i [0 ] < adim [0 ]; i [0 ]++ )
134+ {
135+ float * const ap0 = ap + i [0 ] * astride [0 ];
136+ float * const bp0 = bp + i [0 ] * bstride [0 ];
137+ float * const varp0 = rdim [0 ] == 1 ? varp : varp + i [0 ] * saved_inv_std_stride [0 ];
138+ for (i [1 ] = 0 ; i [1 ] < adim [1 ]; i [1 ]++ )
139+ {
140+ float * ap1 = ap0 + i [1 ] * astride [1 ];
141+ float * bp1 = bp0 + i [1 ] * bstride [1 ];
142+ float * const varp1 = rdim [1 ] == 1 ? varp0 : varp0 + i [1 ] * saved_inv_std_stride [1 ];
143+ for (i [2 ] = 0 ; i [2 ] < adim [2 ]; i [2 ]++ )
144+ {
145+ float * const varp2 = rdim [2 ] == 1 ? varp1 : varp1 + i [2 ] * saved_inv_std_stride [2 ];
146+ if (rdim [3 ] == 1 )
147+ for (x = 0 ; x < adim [3 ]; x ++ )
148+ bp1 [x ] = ap1 [x * astride [3 ]] * varp2 [0 ];
149+ else
150+ for (x = 0 ; x < adim [3 ]; x ++ )
151+ bp1 [x ] = ap1 [x * astride [3 ]] * varp2 [x ];
152+ ap1 += astride [2 ];
153+ bp1 += bstride [2 ];
154+ }
124155 }
125156 }
126157 }
@@ -129,12 +160,13 @@ static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
129160
130161static int _ccv_nnc_rmsnorm_back (const ccv_nnc_cmd_t cmd , const ccv_nnc_hint_t hint , const int flags , ccv_nnc_tensor_t * const * const inputs , const int input_size , ccv_nnc_tensor_t * const * const outputs , const int output_size , ccv_nnc_stream_context_t * const stream_context )
131162{
132- assert (input_size == 6 );
163+ assert (input_size == 6 || input_size == 5 );
133164 assert (output_size >= 1 );
165+ const int elementwise_affine = cmd .info .rmsnorm .elementwise_affine ;
134166 ccv_nnc_tensor_view_t * const g = (ccv_nnc_tensor_view_t * )inputs [0 ];
135167 ccv_nnc_tensor_view_t * const a = (ccv_nnc_tensor_view_t * )inputs [2 ];
136- ccv_nnc_tensor_view_t * const scale = (ccv_nnc_tensor_view_t * )inputs [3 ];
137- ccv_nnc_tensor_view_t * const saved_inv_std = (ccv_nnc_tensor_view_t * )inputs [5 ];
168+ ccv_nnc_tensor_view_t * const scale = elementwise_affine ? (ccv_nnc_tensor_view_t * )inputs [3 ] : 0 ;
169+ ccv_nnc_tensor_view_t * const saved_inv_std = (ccv_nnc_tensor_view_t * )inputs [elementwise_affine ? 5 : 4 ];
138170 ccv_nnc_tensor_view_t * const h = (ccv_nnc_tensor_view_t * )outputs [0 ];
139171 ccv_nnc_tensor_view_t * const dscale = output_size > 1 ? (ccv_nnc_tensor_view_t * )outputs [1 ] : 0 ;
140172 assert (ccv_nnc_tensor_nd (g -> info .dim ) <= CCV_NNC_MAX_DIM + 2 );
@@ -146,7 +178,8 @@ static int _ccv_nnc_rmsnorm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
146178 ccv_nnc_tensor_view_get_dim (g , gdim );
147179 ccv_nnc_tensor_view_get_dim (saved_inv_std , rdim );
148180 int sdim [CCV_NNC_MAX_DIM_ALLOC ];
149- ccv_nnc_tensor_view_get_dim (scale , sdim );
181+ if (scale )
182+ ccv_nnc_tensor_view_get_dim (scale , sdim );
150183 if (dscale )
151184 { assert (ccv_nnc_tensor_view_check_dim (dscale , sdim )); }
152185 assert (ccv_nnc_tensor_view_check_dim (a , gdim ));
@@ -160,7 +193,8 @@ static int _ccv_nnc_rmsnorm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
160193 ccv_nnc_tensor_view_get_stride (a , astride );
161194 ccv_nnc_tensor_view_get_stride (g , gstride );
162195 ccv_nnc_tensor_view_get_stride (h , hstride );
163- ccv_nnc_tensor_view_get_stride (scale , scale_stride );
196+ if (scale )
197+ ccv_nnc_tensor_view_get_stride (scale , scale_stride );
164198 ccv_nnc_tensor_view_get_stride (saved_inv_std , inv_std_stride );
165199 if (dscale )
166200 ccv_nnc_tensor_view_get_stride (dscale , dscale_stride );
@@ -252,29 +286,55 @@ static int _ccv_nnc_rmsnorm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
252286 } else {
253287 float * gssp = gss ;
254288 const float * const gp = g -> data .f32 ;
255- const float * const scalep = scale -> data .f32 ;
256- for (i [0 ] = 0 ; i [0 ] < gdim [0 ]; i [0 ]++ )
289+ if (elementwise_affine )
257290 {
258- const float * const gp0 = gp + i [0 ] * gstride [0 ];
259- const float * const inv_stdp0 = rdim [0 ] == 1 ? inv_stdp : inv_stdp + i [0 ] * inv_std_stride [0 ];
260- const float * const scalep0 = sdim [0 ] == 1 ? scalep : scalep + i [0 ] * scale_stride [0 ];
261- for (i [1 ] = 0 ; i [1 ] < gdim [1 ]; i [1 ]++ )
291+ const float * const scalep = scale -> data .f32 ;
292+ for (i [0 ] = 0 ; i [0 ] < gdim [0 ]; i [0 ]++ )
262293 {
263- const float * gp1 = gp0 + i [1 ] * gstride [1 ];
264- const float * const inv_stdp1 = rdim [1 ] == 1 ? inv_stdp0 : inv_stdp0 + i [1 ] * inv_std_stride [1 ];
265- const float * const scalep1 = sdim [1 ] == 1 ? scalep0 : scalep0 + i [1 ] * scale_stride [1 ];
266- for (i [2 ] = 0 ; i [2 ] < gdim [2 ]; i [2 ]++ )
294+ const float * const gp0 = gp + i [0 ] * gstride [0 ];
295+ const float * const inv_stdp0 = rdim [0 ] == 1 ? inv_stdp : inv_stdp + i [0 ] * inv_std_stride [0 ];
296+ const float * const scalep0 = sdim [0 ] == 1 ? scalep : scalep + i [0 ] * scale_stride [0 ];
297+ for (i [1 ] = 0 ; i [1 ] < gdim [1 ]; i [1 ]++ )
267298 {
268- const float * const inv_stdp2 = rdim [2 ] == 1 ? inv_stdp1 : inv_stdp1 + i [2 ] * inv_std_stride [2 ];
269- const float * const scalep2 = sdim [2 ] == 1 ? scalep1 : scalep1 + i [2 ] * scale_stride [2 ];
270- if (sdim [3 ] == 1 )
271- for (x = 0 ; x < gdim [3 ]; x ++ )
272- gssp [x ] = gp1 [x ] * scalep2 [0 ] * inv_stdp2 [rdim [3 ] == 1 ? 0 : x ];
273- else
274- for (x = 0 ; x < gdim [3 ]; x ++ )
275- gssp [x ] = gp1 [x ] * scalep2 [x ] * inv_stdp2 [rdim [3 ] == 1 ? 0 : x ];
276- gp1 += gstride [2 ];
277- gssp += gdim [3 ];
299+ const float * gp1 = gp0 + i [1 ] * gstride [1 ];
300+ const float * const inv_stdp1 = rdim [1 ] == 1 ? inv_stdp0 : inv_stdp0 + i [1 ] * inv_std_stride [1 ];
301+ const float * const scalep1 = sdim [1 ] == 1 ? scalep0 : scalep0 + i [1 ] * scale_stride [1 ];
302+ for (i [2 ] = 0 ; i [2 ] < gdim [2 ]; i [2 ]++ )
303+ {
304+ const float * const inv_stdp2 = rdim [2 ] == 1 ? inv_stdp1 : inv_stdp1 + i [2 ] * inv_std_stride [2 ];
305+ const float * const scalep2 = sdim [2 ] == 1 ? scalep1 : scalep1 + i [2 ] * scale_stride [2 ];
306+ if (sdim [3 ] == 1 )
307+ for (x = 0 ; x < gdim [3 ]; x ++ )
308+ gssp [x ] = gp1 [x ] * scalep2 [0 ] * inv_stdp2 [rdim [3 ] == 1 ? 0 : x ];
309+ else
310+ for (x = 0 ; x < gdim [3 ]; x ++ )
311+ gssp [x ] = gp1 [x ] * scalep2 [x ] * inv_stdp2 [rdim [3 ] == 1 ? 0 : x ];
312+ gp1 += gstride [2 ];
313+ gssp += gdim [3 ];
314+ }
315+ }
316+ }
317+ } else {
318+ for (i [0 ] = 0 ; i [0 ] < gdim [0 ]; i [0 ]++ )
319+ {
320+ const float * const gp0 = gp + i [0 ] * gstride [0 ];
321+ const float * const inv_stdp0 = rdim [0 ] == 1 ? inv_stdp : inv_stdp + i [0 ] * inv_std_stride [0 ];
322+ for (i [1 ] = 0 ; i [1 ] < gdim [1 ]; i [1 ]++ )
323+ {
324+ const float * gp1 = gp0 + i [1 ] * gstride [1 ];
325+ const float * const inv_stdp1 = rdim [1 ] == 1 ? inv_stdp0 : inv_stdp0 + i [1 ] * inv_std_stride [1 ];
326+ for (i [2 ] = 0 ; i [2 ] < gdim [2 ]; i [2 ]++ )
327+ {
328+ const float * const inv_stdp2 = rdim [2 ] == 1 ? inv_stdp1 : inv_stdp1 + i [2 ] * inv_std_stride [2 ];
329+ if (rdim [3 ] == 1 )
330+ for (x = 0 ; x < gdim [3 ]; x ++ )
331+ gssp [x ] = gp1 [x ] * inv_stdp2 [0 ];
332+ else
333+ for (x = 0 ; x < gdim [3 ]; x ++ )
334+ gssp [x ] = gp1 [x ] * inv_stdp2 [x ];
335+ gp1 += gstride [2 ];
336+ gssp += gdim [3 ];
337+ }
278338 }
279339 }
280340 }
0 commit comments