@@ -367,49 +367,90 @@ def table_info_precomputation(momentum_prefix: str = "momentum1") -> str:
 
 def rowwise_adagrad() -> None:
     split_weight_update = """
-      weight_new.fma_(grad, -multiplier);
+      weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x;
+      weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y;
+      weight_new.acc.z = correction * weight_new.acc.z - multiplier * grad.acc.z;
+      weight_new.acc.w = correction * weight_new.acc.w - multiplier * grad.acc.w;
     """
     split_precomputation = """
     acc_type<cache_t, true> g_local_sum_square = 0.0;
     #pragma unroll kMaxVecsPerThread
     for (int32_t i = 0;
         i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
         ++i) {
-      g_local_sum_square += grad_sum[i].acc.x * grad_sum[i].acc.x +
-          grad_sum[i].acc.y * grad_sum[i].acc.y +
-          grad_sum[i].acc.z * grad_sum[i].acc.z +
-          grad_sum[i].acc.w * grad_sum[i].acc.w;
+      auto gx = grad_sum[i].acc.x;
+      auto gy = grad_sum[i].acc.y;
+      auto gz = grad_sum[i].acc.z;
+      auto gw = grad_sum[i].acc.w;
+      if (weight_decay_mode == 0) {
+        // L2 regularization
+        int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
+        Vec4T<acc_type<cache_t, true>> weight = weight_row_template.load(d, qparams_template);
+        gx += weight_decay * weight.acc.x;
+        gy += weight_decay * weight.acc.y;
+        gz += weight_decay * weight.acc.z;
+        gw += weight_decay * weight.acc.w;
+      }
+      g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
     }
     const acc_type<cache_t, true> g_avg_square =
         warpReduceAllSum<acc_type<cache_t, true>>(g_local_sum_square) / D;
 
     acc_type<cache_t, true> multiplier;
+    acc_type<cache_t, true> correction = 1.0;
     if (threadIdx.x == 0) {
       acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
       momentum1[idx] = new_sum_square_grads;
       multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
+      if (weight_decay_mode == 0) {
+        // L2 regularization
+        correction = 1 - multiplier * weight_decay;
+      } else if (weight_decay_mode == 1) {
+        // Decoupled weight decay
+        correction = 1 - learning_rate * weight_decay;
+      }
     }
     multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0);
+    correction = __shfl_sync(0xFFFFFFFF, correction, 0);
     """
     split_weight_update_cpu = """
     acc_type<scalar_t, true> g_local_sum_square = 0.0;
     for (int64_t d = 0; d < D; ++d) {
-      g_local_sum_square += grad_buffer[d] * grad_buffer[d];
+      auto grad = grad_buffer[d];
+      if (weight_decay_mode == 0) {
+        // L2 regularization
+        grad += weight_decay * host_weights_data[embedding_begin + d];
+      }
+      g_local_sum_square += grad * grad;
     }
     auto g_avg_square = g_local_sum_square / D;
     acc_type<scalar_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square;
     momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
     acc_type<scalar_t, true> multiplier;
     multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
+    acc_type<scalar_t, true> correction = 1.0;
+    if (weight_decay_mode == 0) {
+      // L2 regularization
+      correction = 1 - multiplier * weight_decay;
+    } else if (weight_decay_mode == 1) {
+      // Decoupled weight decay
+      correction = 1 - learning_rate * weight_decay;
+    }
     for (int64_t d = 0; d < D; ++d) {
-      host_weights_data[embedding_begin + d] -= grad_buffer[d] * multiplier;
+      host_weights_data[embedding_begin + d] = correction * host_weights_data[embedding_begin + d] - grad_buffer[d] * multiplier;
     }
     """
 
     generate(
         optimizer="rowwise_adagrad",
         args=make_args(
-            [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")]
+            [
+                (TENSOR, "momentum1"),
+                (FLOAT, "eps"),
+                (FLOAT, "learning_rate"),
+                (FLOAT, "weight_decay"),
+                (INT, "weight_decay_mode"),
+            ]
         ),
         split_precomputation=split_precomputation,
         split_weight_update=split_weight_update,
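
Aside (not part of the patch): a minimal NumPy sketch of the per-row update that the generated GPU and CPU kernels above implement; the function name and signature are illustrative. With `weight_decay_mode == 0` (L2 regularization) the decay term is folded into the gradient that feeds the adagrad statistic, and the weight update absorbs it through `correction`, since `correction * w - multiplier * g == w - multiplier * (g + weight_decay * w)`. With `weight_decay_mode == 1` (decoupled weight decay) the statistic is left untouched and the weights are simply shrunk by `1 - learning_rate * weight_decay`.

    import numpy as np

    def rowwise_adagrad_reference(weights, grad, momentum1, learning_rate, eps,
                                  weight_decay, weight_decay_mode):
        """Illustrative reference for one embedding row of shape (D,)."""
        # Mode 0 (L2) adds the decay term before squaring, so it also feeds
        # the adagrad statistic; mode 1 (decoupled) leaves the statistic alone.
        g = grad + weight_decay * weights if weight_decay_mode == 0 else grad
        momentum1 = momentum1 + float(np.mean(g * g))  # row-wise mean: one scalar per row
        multiplier = learning_rate / (np.sqrt(momentum1) + eps)
        if weight_decay_mode == 0:
            correction = 1.0 - multiplier * weight_decay     # L2 regularization
        elif weight_decay_mode == 1:
            correction = 1.0 - learning_rate * weight_decay  # decoupled weight decay
        else:
            correction = 1.0
        # Same form as the generated code: w_new = correction * w - multiplier * grad.
        return correction * weights - multiplier * grad, momentum1
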
@@ -425,7 +466,13 @@ def rowwise_adagrad() -> None:
     generate(
         optimizer="approx_rowwise_adagrad",
         args=make_args(
-            [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")]
+            [
+                (TENSOR, "momentum1"),
+                (FLOAT, "eps"),
+                (FLOAT, "learning_rate"),
+                (FLOAT, "weight_decay"),
+                (INT, "weight_decay_mode"),
+            ]
         ),
         split_precomputation=split_precomputation,
         split_weight_update=approx_split_weight_update,
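
Aside (again illustrative, not from the patch): a quick sanity check of the mode-0 semantics using the sketch above. The L2 path should match plain rowwise adagrad run on the augmented gradient `grad + weight_decay * weights`:

    rng = np.random.default_rng(0)
    w = rng.standard_normal(8)
    g = rng.standard_normal(8)
    lr, eps, wd = 0.1, 1e-8, 0.01

    w0, _ = rowwise_adagrad_reference(w, g, 0.0, lr, eps, wd, weight_decay_mode=0)
    # Plain adagrad (weight_decay = 0) on the augmented gradient, same result:
    w1, _ = rowwise_adagrad_reference(w, g + wd * w, 0.0, lr, eps, 0.0, weight_decay_mode=1)
    assert np.allclose(w0, w1)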