Skip to content

Commit ebd7f28

Browse files
committed
Add elementwise_affine support for rmsnorm.
1 parent 1e204eb commit ebd7f28

File tree

9 files changed

+523
-100
lines changed

9 files changed

+523
-100
lines changed

lib/nnc/ccv_cnnp_model_addons.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2555,7 +2555,7 @@ static const ccv_cnnp_model_vtab_t ccv_cnnp_rmsnorm_isa = {
25552555
.copy = _ccv_cnnp_rmsnorm_copy,
25562556
};
25572557

2558-
ccv_cnnp_model_t* ccv_cnnp_rmsnorm(const float epsilon, const int axis[CCV_NNC_MAX_DIM_ALLOC], const int axis_count, const int is_trainable, const char* const name)
2558+
ccv_cnnp_model_t* ccv_cnnp_rmsnorm(const float epsilon, const int axis[CCV_NNC_MAX_DIM_ALLOC], const int axis_count, const int elementwise_affine, const int is_trainable, const char* const name)
25592559
{
25602560
ccv_cnnp_model_rmsnorm_t* const model_rmsnorm = (ccv_cnnp_model_rmsnorm_t*)cccalloc(1, sizeof(ccv_cnnp_model_rmsnorm_t));
25612561
model_rmsnorm->super.isa = &ccv_cnnp_rmsnorm_isa;
@@ -2568,14 +2568,15 @@ ccv_cnnp_model_t* ccv_cnnp_rmsnorm(const float epsilon, const int axis[CCV_NNC_M
25682568
model_rmsnorm->scale.graph = 0;
25692569
model_rmsnorm->params.rmsnorm.epsilon = epsilon;
25702570
model_rmsnorm->params.rmsnorm.count = axis_count;
2571+
model_rmsnorm->params.rmsnorm.elementwise_affine = elementwise_affine;
25712572
memcpy(model_rmsnorm->params.lnorm.axis, axis, sizeof(int) * axis_count);
25722573
return (ccv_cnnp_model_t*)model_rmsnorm;
25732574
}
25742575

25752576
static ccv_cnnp_model_t* _ccv_cnnp_rmsnorm_copy(const ccv_cnnp_model_t* const super, void* const context)
25762577
{
25772578
const ccv_cnnp_model_rmsnorm_t* const self = (const ccv_cnnp_model_rmsnorm_t*)super;
2578-
return ccv_cnnp_rmsnorm(self->params.rmsnorm.epsilon, self->params.rmsnorm.axis, self->params.rmsnorm.count, self->super.is_trainable, self->super.name);
2579+
return ccv_cnnp_rmsnorm(self->params.rmsnorm.epsilon, self->params.rmsnorm.axis, self->params.rmsnorm.count, self->params.rmsnorm.elementwise_affine, self->super.is_trainable, self->super.name);
25792580
}
25802581

25812582
// MARK - Batched Matrix Mul Layer

lib/nnc/ccv_nnc.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ typedef struct {
170170
int axis[CCV_NNC_MAX_DIM_ALLOC]; /**< [rmsnorm.axis[]] The axis selected to compute mean / variance. */
171171
int count; /**< [rmsnorm.count] The number of axis selected. */
172172
float epsilon; /**< [rmsnorm.epsilon] The epsilon for standard deviation. */
173+
int elementwise_affine; /**< [rmsnorm.elementwise_affine] Whether a scale (gamma) is applied to the normalized output. */
173174
} rmsnorm;
174175
struct {
175176
int nesterov; /**< [sgd.nesterov] Nesterov accelerated gradient. */
@@ -4549,11 +4550,12 @@ CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_group_norm(const int group_axis, con
45494550
* @param epsilon The epsilon parameter of rmsnorm.
45504551
* @param axis The feature axes over which the norm is computed.
45514552
* @param axis_count How many axes we count as features.
4553+
* @param elementwise_affine Whether the model applies a scale (gamma) to the normalized output.
45524554
* @param is_trainable Whether the parameters of this model can be trained.
45534555
* @param name The unique name of the model.
45544556
* @return A rmsnorm model.
45554557
*/
4556-
CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_rmsnorm(const float epsilon, const int axis[CCV_NNC_MAX_DIM_ALLOC], const int axis_count, const int is_trainable, const char* const name);
4558+
CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_rmsnorm(const float epsilon, const int axis[CCV_NNC_MAX_DIM_ALLOC], const int axis_count, const int elementwise_affine, const int is_trainable, const char* const name);
45574559
/**
45584560
* Add two input tensors together. Different from sum because this support broadcasting.
45594561
* @param p The weight for the first input.

lib/nnc/cmd/ccv_nnc_cmd_easy.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@
183183
// CCV_NNC_GROUP_NORM_BACKWARD
184184
#define CMD_GROUP_NORM_BACKWARD(_group_axis, _groups, _epsilon, _elementwise_affine, ...) ccv_nnc_cmd(CCV_NNC_GROUP_NORM_BACKWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.gnorm={.group_axis=_group_axis,.groups=_groups,.epsilon=_epsilon,.elementwise_affine=_elementwise_affine,.reduce_count=LIST_COUNT(__VA_ARGS__),.reduce_axis={__VA_ARGS__}}}), 0)
185185
// CCV_NNC_RMSNORM_FORWARD
186-
#define CMD_RMSNORM_FORWARD(_epsilon, ...) ccv_nnc_cmd(CCV_NNC_RMSNORM_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.rmsnorm={.epsilon=_epsilon,.count=LIST_COUNT(__VA_ARGS__),.axis={__VA_ARGS__}}}), 0)
186+
#define CMD_RMSNORM_FORWARD(_epsilon, _elementwise_affine, ...) ccv_nnc_cmd(CCV_NNC_RMSNORM_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.rmsnorm={.epsilon=_epsilon,.elementwise_affine=_elementwise_affine,.count=LIST_COUNT(__VA_ARGS__),.axis={__VA_ARGS__}}}), 0)
187187
// CCV_NNC_RMSNORM_BACKWARD
188-
#define CMD_RMSNORM_BACKWARD(_epsilon, ...) ccv_nnc_cmd(CCV_NNC_RMSNORM_BACKWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.rmsnorm={.epsilon=_epsilon,.count=LIST_COUNT(__VA_ARGS__),.axis={__VA_ARGS__}}}), 0)
188+
#define CMD_RMSNORM_BACKWARD(_epsilon, _elementwise_affine, ...) ccv_nnc_cmd(CCV_NNC_RMSNORM_BACKWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.rmsnorm={.epsilon=_epsilon,.elementwise_affine=_elementwise_affine,.count=LIST_COUNT(__VA_ARGS__),.axis={__VA_ARGS__}}}), 0)
189189
// CCV_NNC_PAD_FORWARD
190190
#define CMD_PAD_FORWARD(_type, _begin, _end) ccv_nnc_cmd(CCV_NNC_PAD_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={ESCAPE_X _begin}},.pad={.type=_type,.end={ESCAPE_X _end}}}), 0)
191191
// CCV_NNC_PAD_BACKWARD

lib/nnc/cmd/norm/ccv_nnc_norm.c

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -287,28 +287,45 @@ REGISTER_COMMAND(CCV_NNC_GROUP_NORM_BACKWARD)(ccv_nnc_cmd_registry_t* const regi
287287

288288
static int _ccv_nnc_rmsnorm_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
289289
{
290-
// 2 inputs (x, gamma)
291-
// 2 outputs (y, saved_inv_std)
292-
if (input_bitmasks[0] == 3u && output_bitmasks[0] == 3u)
293-
return 1;
290+
if (cmd.rmsnorm.elementwise_affine)
291+
{
292+
// 2 inputs (x, gamma)
293+
// 2 outputs (y, saved_inv_std)
294+
if (input_bitmasks[0] == 3u && output_bitmasks[0] == 3u)
295+
return 1;
296+
} else {
297+
// 1 input (x)
298+
// 2 outputs (y, saved_inv_std)
299+
if (input_bitmasks[0] == 1u && output_bitmasks[0] == 3u)
300+
return 1;
301+
}
294302
return 0;
295303
}
296304

297305
static int _ccv_nnc_rmsnorm_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
298306
{
299-
// 1 + 4 + 8 + 32
300-
// Inputs (gradient, 0, x, gamma, 0, saved_inv_std)
301-
// Output the propagated error, dgamma
302-
if ((input_bitmasks[0] & 45u) == 45u && (output_bitmasks[0] & 3u) == 3u)
303-
return 1;
304-
if ((input_bitmasks[0] & 45u) == 45u && (output_bitmasks[0] & 1u) == 1u)
305-
return 1;
307+
if (cmd.rmsnorm.elementwise_affine)
308+
{
309+
// 1 + 4 + 8 + 32
310+
// Inputs (gradient, 0, x, gamma, 0, saved_inv_std)
311+
// Output the propagated error, dgamma
312+
if ((input_bitmasks[0] & 45u) == 45u && (output_bitmasks[0] & 3u) == 3u)
313+
return 1;
314+
if ((input_bitmasks[0] & 45u) == 45u && (output_bitmasks[0] & 1u) == 1u)
315+
return 1;
316+
} else {
317+
// 1 + 4 + 16
318+
// Inputs (gradient, 0, x, 0, saved_inv_std)
319+
// Output the propagated error
320+
if ((input_bitmasks[0] & 21u) == 21u && (output_bitmasks[0] & 1u) == 1u)
321+
return 1;
322+
}
306323
return 0;
307324
}
308325

309326
static void _ccv_nnc_rmsnorm_tensor_auto_forw(const ccv_nnc_cmd_param_t cmd, const ccv_nnc_tensor_param_t* const inputs, const int input_size, const ccv_nnc_hint_t hint, ccv_nnc_tensor_param_t* const outputs, const int output_size)
310327
{
311-
assert(input_size == 2);
328+
assert(input_size == 2 || input_size == 1);
312329
assert(output_size == 1 || output_size == 2);
313330
outputs[0] = inputs[0];
314331
if (output_size == 1)
@@ -324,7 +341,7 @@ static void _ccv_nnc_rmsnorm_tensor_auto_forw(const ccv_nnc_cmd_param_t cmd, con
324341

325342
static void _ccv_nnc_rmsnorm_tensor_auto_back(const ccv_nnc_cmd_param_t cmd, const ccv_nnc_tensor_param_t* const inputs, const int input_size, const ccv_nnc_hint_t hint, ccv_nnc_tensor_param_t* const outputs, const int output_size)
326343
{
327-
assert(input_size == 6);
344+
assert(input_size == 6 || input_size == 5);
328345
assert(output_size == 1 || output_size == 2);
329346
outputs[0] = inputs[0];
330347
int i, j;
@@ -351,6 +368,6 @@ REGISTER_COMMAND(CCV_NNC_RMSNORM_BACKWARD)(ccv_nnc_cmd_registry_t* const registr
351368
}
352369

353370
//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_RMSNORM_FORWARD)
354-
#define CMD_RMSNORM_FORWARD(_epsilon, ...) ccv_nnc_cmd(CCV_NNC_RMSNORM_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.rmsnorm={.epsilon=_epsilon,.count=LIST_COUNT(__VA_ARGS__),.axis={__VA_ARGS__}}}), 0)
371+
#define CMD_RMSNORM_FORWARD(_epsilon, _elementwise_affine, ...) ccv_nnc_cmd(CCV_NNC_RMSNORM_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.rmsnorm={.epsilon=_epsilon,.elementwise_affine=_elementwise_affine,.count=LIST_COUNT(__VA_ARGS__),.axis={__VA_ARGS__}}}), 0)
355372
//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_RMSNORM_BACKWARD)
356-
#define CMD_RMSNORM_BACKWARD(_epsilon, ...) ccv_nnc_cmd(CCV_NNC_RMSNORM_BACKWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.rmsnorm={.epsilon=_epsilon,.count=LIST_COUNT(__VA_ARGS__),.axis={__VA_ARGS__}}}), 0)
373+
#define CMD_RMSNORM_BACKWARD(_epsilon, _elementwise_affine, ...) ccv_nnc_cmd(CCV_NNC_RMSNORM_BACKWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.rmsnorm={.epsilon=_epsilon,.elementwise_affine=_elementwise_affine,.count=LIST_COUNT(__VA_ARGS__),.axis={__VA_ARGS__}}}), 0)

lib/nnc/cmd/norm/ccv_nnc_rmsnorm_cpu_ref.c

Lines changed: 115 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
1717
{
18-
assert(input_size == 2);
18+
assert(input_size == 2 || input_size == 1);
1919
ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[0];
20-
ccv_nnc_tensor_view_t* const scale = (ccv_nnc_tensor_view_t*)inputs[1];
20+
ccv_nnc_tensor_view_t* const scale = input_size >= 2 ? (ccv_nnc_tensor_view_t*)inputs[1] : 0;
2121
ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0];
2222
ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)outputs[1];
2323
assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
@@ -33,7 +33,8 @@ static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
3333
int bstride[CCV_NNC_MAX_DIM_ALLOC];
3434
int scale_stride[CCV_NNC_MAX_DIM_ALLOC];
3535
ccv_nnc_tensor_view_get_stride(a, astride);
36-
ccv_nnc_tensor_view_get_stride(scale, scale_stride);
36+
if (scale)
37+
ccv_nnc_tensor_view_get_stride(scale, scale_stride);
3738
ccv_nnc_tensor_view_get_stride(b, bstride);
3839
// The epsilon is used a little bit differently from batch norm, it is outside of the sqrt in this case.
3940
const float epsilon = cmd.info.rmsnorm.epsilon;
@@ -91,36 +92,66 @@ static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
9192
}
9293
}
9394
}
94-
float* const scalep = scale->data.f32;
95-
int sdim[CCV_NNC_MAX_DIM_ALLOC];
96-
ccv_nnc_tensor_view_get_dim(scale, sdim);
97-
// Do the straight-forward one, y = x * inv_std * scale + bias, we cannot allocate extra memory to help.
98-
// There is no need for precompute since scale / bias is per element.
99-
float* const bp = b->data.f32;
100-
for (i[0] = 0; i[0] < adim[0]; i[0]++)
95+
if (cmd.info.rmsnorm.elementwise_affine)
10196
{
102-
float* const ap0 = ap + i[0] * astride[0];
103-
float* const bp0 = bp + i[0] * bstride[0];
104-
float* const varp0 = rdim[0] == 1 ? varp : varp + i[0] * saved_inv_std_stride[0];
105-
float* const scalep0 = sdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0];
106-
for (i[1] = 0; i[1] < adim[1]; i[1]++)
97+
float* const scalep = scale->data.f32;
98+
int sdim[CCV_NNC_MAX_DIM_ALLOC];
99+
ccv_nnc_tensor_view_get_dim(scale, sdim);
100+
// Do the straight-forward one, y = x * inv_std * scale, we cannot allocate extra memory to help.
101+
// There is no need to precompute since scale is per element.
102+
float* const bp = b->data.f32;
103+
for (i[0] = 0; i[0] < adim[0]; i[0]++)
107104
{
108-
float* ap1 = ap0 + i[1] * astride[1];
109-
float* bp1 = bp0 + i[1] * bstride[1];
110-
float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1];
111-
float* const scalep1 = sdim[1] == 1 ? scalep0 : scalep0 + i[1] * scale_stride[1];
112-
for (i[2] = 0; i[2] < adim[2]; i[2]++)
105+
float* const ap0 = ap + i[0] * astride[0];
106+
float* const bp0 = bp + i[0] * bstride[0];
107+
float* const varp0 = rdim[0] == 1 ? varp : varp + i[0] * saved_inv_std_stride[0];
108+
float* const scalep0 = sdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0];
109+
for (i[1] = 0; i[1] < adim[1]; i[1]++)
113110
{
114-
float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2];
115-
float* const scalep2 = sdim[2] == 1 ? scalep1 : scalep1 + i[2] * scale_stride[2];
116-
if (rdim[3] == 1)
117-
for (x = 0; x < adim[3]; x++)
118-
bp1[x] = ap1[x * astride[3]] * varp2[0] * scalep2[sdim[3] == 1 ? 0 : x];
119-
else
120-
for (x = 0; x < adim[3]; x++)
121-
bp1[x] = ap1[x * astride[3]] * varp2[x] * scalep2[sdim[3] == 1 ? 0 : x];
122-
ap1 += astride[2];
123-
bp1 += bstride[2];
111+
float* ap1 = ap0 + i[1] * astride[1];
112+
float* bp1 = bp0 + i[1] * bstride[1];
113+
float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1];
114+
float* const scalep1 = sdim[1] == 1 ? scalep0 : scalep0 + i[1] * scale_stride[1];
115+
for (i[2] = 0; i[2] < adim[2]; i[2]++)
116+
{
117+
float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2];
118+
float* const scalep2 = sdim[2] == 1 ? scalep1 : scalep1 + i[2] * scale_stride[2];
119+
if (rdim[3] == 1)
120+
for (x = 0; x < adim[3]; x++)
121+
bp1[x] = ap1[x * astride[3]] * varp2[0] * scalep2[sdim[3] == 1 ? 0 : x];
122+
else
123+
for (x = 0; x < adim[3]; x++)
124+
bp1[x] = ap1[x * astride[3]] * varp2[x] * scalep2[sdim[3] == 1 ? 0 : x];
125+
ap1 += astride[2];
126+
bp1 += bstride[2];
127+
}
128+
}
129+
}
130+
} else {
131+
// Do the straight-forward one, y = x * inv_std, we cannot allocate extra memory to help.
132+
float* const bp = b->data.f32;
133+
for (i[0] = 0; i[0] < adim[0]; i[0]++)
134+
{
135+
float* const ap0 = ap + i[0] * astride[0];
136+
float* const bp0 = bp + i[0] * bstride[0];
137+
float* const varp0 = rdim[0] == 1 ? varp : varp + i[0] * saved_inv_std_stride[0];
138+
for (i[1] = 0; i[1] < adim[1]; i[1]++)
139+
{
140+
float* ap1 = ap0 + i[1] * astride[1];
141+
float* bp1 = bp0 + i[1] * bstride[1];
142+
float* const varp1 = rdim[1] == 1 ? varp0 : varp0 + i[1] * saved_inv_std_stride[1];
143+
for (i[2] = 0; i[2] < adim[2]; i[2]++)
144+
{
145+
float* const varp2 = rdim[2] == 1 ? varp1 : varp1 + i[2] * saved_inv_std_stride[2];
146+
if (rdim[3] == 1)
147+
for (x = 0; x < adim[3]; x++)
148+
bp1[x] = ap1[x * astride[3]] * varp2[0];
149+
else
150+
for (x = 0; x < adim[3]; x++)
151+
bp1[x] = ap1[x * astride[3]] * varp2[x];
152+
ap1 += astride[2];
153+
bp1 += bstride[2];
154+
}
124155
}
125156
}
126157
}
@@ -129,12 +160,13 @@ static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
129160

130161
static int _ccv_nnc_rmsnorm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
131162
{
132-
assert(input_size == 6);
163+
assert(input_size == 6 || input_size == 5);
133164
assert(output_size >= 1);
165+
const int elementwise_affine = cmd.info.rmsnorm.elementwise_affine;
134166
ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
135167
ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[2];
136-
ccv_nnc_tensor_view_t* const scale = (ccv_nnc_tensor_view_t*)inputs[3];
137-
ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)inputs[5];
168+
ccv_nnc_tensor_view_t* const scale = elementwise_affine ? (ccv_nnc_tensor_view_t*)inputs[3] : 0;
169+
ccv_nnc_tensor_view_t* const saved_inv_std = (ccv_nnc_tensor_view_t*)inputs[elementwise_affine ? 5 : 4];
138170
ccv_nnc_tensor_view_t* const h = (ccv_nnc_tensor_view_t*)outputs[0];
139171
ccv_nnc_tensor_view_t* const dscale = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0;
140172
assert(ccv_nnc_tensor_nd(g->info.dim) <= CCV_NNC_MAX_DIM + 2);
@@ -146,7 +178,8 @@ static int _ccv_nnc_rmsnorm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
146178
ccv_nnc_tensor_view_get_dim(g, gdim);
147179
ccv_nnc_tensor_view_get_dim(saved_inv_std, rdim);
148180
int sdim[CCV_NNC_MAX_DIM_ALLOC];
149-
ccv_nnc_tensor_view_get_dim(scale, sdim);
181+
if (scale)
182+
ccv_nnc_tensor_view_get_dim(scale, sdim);
150183
if (dscale)
151184
{ assert(ccv_nnc_tensor_view_check_dim(dscale, sdim)); }
152185
assert(ccv_nnc_tensor_view_check_dim(a, gdim));
@@ -160,7 +193,8 @@ static int _ccv_nnc_rmsnorm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
160193
ccv_nnc_tensor_view_get_stride(a, astride);
161194
ccv_nnc_tensor_view_get_stride(g, gstride);
162195
ccv_nnc_tensor_view_get_stride(h, hstride);
163-
ccv_nnc_tensor_view_get_stride(scale, scale_stride);
196+
if (scale)
197+
ccv_nnc_tensor_view_get_stride(scale, scale_stride);
164198
ccv_nnc_tensor_view_get_stride(saved_inv_std, inv_std_stride);
165199
if (dscale)
166200
ccv_nnc_tensor_view_get_stride(dscale, dscale_stride);
@@ -252,29 +286,55 @@ static int _ccv_nnc_rmsnorm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
252286
} else {
253287
float* gssp = gss;
254288
const float* const gp = g->data.f32;
255-
const float* const scalep = scale->data.f32;
256-
for (i[0] = 0; i[0] < gdim[0]; i[0]++)
289+
if (elementwise_affine)
257290
{
258-
const float* const gp0 = gp + i[0] * gstride[0];
259-
const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp : inv_stdp + i[0] * inv_std_stride[0];
260-
const float* const scalep0 = sdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0];
261-
for (i[1] = 0; i[1] < gdim[1]; i[1]++)
291+
const float* const scalep = scale->data.f32;
292+
for (i[0] = 0; i[0] < gdim[0]; i[0]++)
262293
{
263-
const float* gp1 = gp0 + i[1] * gstride[1];
264-
const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1];
265-
const float* const scalep1 = sdim[1] == 1 ? scalep0 : scalep0 + i[1] * scale_stride[1];
266-
for (i[2] = 0; i[2] < gdim[2]; i[2]++)
294+
const float* const gp0 = gp + i[0] * gstride[0];
295+
const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp : inv_stdp + i[0] * inv_std_stride[0];
296+
const float* const scalep0 = sdim[0] == 1 ? scalep : scalep + i[0] * scale_stride[0];
297+
for (i[1] = 0; i[1] < gdim[1]; i[1]++)
267298
{
268-
const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2];
269-
const float* const scalep2 = sdim[2] == 1 ? scalep1 : scalep1 + i[2] * scale_stride[2];
270-
if (sdim[3] == 1)
271-
for (x = 0; x < gdim[3]; x++)
272-
gssp[x] = gp1[x] * scalep2[0] * inv_stdp2[rdim[3] == 1 ? 0 : x];
273-
else
274-
for (x = 0; x < gdim[3]; x++)
275-
gssp[x] = gp1[x] * scalep2[x] * inv_stdp2[rdim[3] == 1 ? 0 : x];
276-
gp1 += gstride[2];
277-
gssp += gdim[3];
299+
const float* gp1 = gp0 + i[1] * gstride[1];
300+
const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1];
301+
const float* const scalep1 = sdim[1] == 1 ? scalep0 : scalep0 + i[1] * scale_stride[1];
302+
for (i[2] = 0; i[2] < gdim[2]; i[2]++)
303+
{
304+
const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2];
305+
const float* const scalep2 = sdim[2] == 1 ? scalep1 : scalep1 + i[2] * scale_stride[2];
306+
if (sdim[3] == 1)
307+
for (x = 0; x < gdim[3]; x++)
308+
gssp[x] = gp1[x] * scalep2[0] * inv_stdp2[rdim[3] == 1 ? 0 : x];
309+
else
310+
for (x = 0; x < gdim[3]; x++)
311+
gssp[x] = gp1[x] * scalep2[x] * inv_stdp2[rdim[3] == 1 ? 0 : x];
312+
gp1 += gstride[2];
313+
gssp += gdim[3];
314+
}
315+
}
316+
}
317+
} else {
318+
for (i[0] = 0; i[0] < gdim[0]; i[0]++)
319+
{
320+
const float* const gp0 = gp + i[0] * gstride[0];
321+
const float* const inv_stdp0 = rdim[0] == 1 ? inv_stdp : inv_stdp + i[0] * inv_std_stride[0];
322+
for (i[1] = 0; i[1] < gdim[1]; i[1]++)
323+
{
324+
const float* gp1 = gp0 + i[1] * gstride[1];
325+
const float* const inv_stdp1 = rdim[1] == 1 ? inv_stdp0 : inv_stdp0 + i[1] * inv_std_stride[1];
326+
for (i[2] = 0; i[2] < gdim[2]; i[2]++)
327+
{
328+
const float* const inv_stdp2 = rdim[2] == 1 ? inv_stdp1 : inv_stdp1 + i[2] * inv_std_stride[2];
329+
if (rdim[3] == 1)
330+
for (x = 0; x < gdim[3]; x++)
331+
gssp[x] = gp1[x] * inv_stdp2[0];
332+
else
333+
for (x = 0; x < gdim[3]; x++)
334+
gssp[x] = gp1[x] * inv_stdp2[x];
335+
gp1 += gstride[2];
336+
gssp += gdim[3];
337+
}
278338
}
279339
}
280340
}

0 commit comments

Comments
 (0)