
Commit 303cc2f

Yuxi Hu authored and facebook-github-bot committed
Support L2 regularization and decoupled weight decay in rowwise adagrad (#718)
Summary: Pull Request resolved: #718

Add two kinds of weight decay to rowwise adagrad.

L2 regularization:

```
g' = g + weight_decay * w
multiplier = lr / (sqrt(v) + eps)
w = w - lr * g' / (sqrt(v) + eps)
  = w - lr * g / (sqrt(v) + eps) - lr * weight_decay * w / (sqrt(v) + eps)
  = (1 - multiplier * weight_decay) * w - multiplier * g
```

Decoupled weight decay:

```
multiplier = lr / (sqrt(v) + eps)
w = w - lr * (g / (sqrt(v) + eps) + weight_decay * w)
  = w - lr * g / (sqrt(v) + eps) - lr * weight_decay * w
  = (1 - lr * weight_decay) * w - multiplier * g
```

Reviewed By: choudharydhruv

Differential Revision: D31285351

fbshipit-source-id: e361627f8426856021badef0410455e23620f21b
1 parent 5243fc4 commit 303cc2f

6 files changed

Lines changed: 109 additions & 26 deletions
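
To make the algebra in the commit summary concrete, here is a minimal NumPy sketch of one rowwise-adagrad step for a single row under both modes. All names are illustrative rather than taken from the FBGEMM sources; `v` stands for the row's accumulated mean of squared gradients (`momentum1` in the diffs below).

```python
import numpy as np

def rowwise_adagrad_step(w, g, v, lr, eps, weight_decay, mode):
    """One rowwise-adagrad step for a single row; mode 0 = L2, mode 1 = decoupled."""
    # L2 folds the decay into the gradient before the second-moment update;
    # decoupled weight decay leaves the accumulator untouched.
    g_eff = g + weight_decay * w if mode == 0 else g
    v_new = v + np.mean(g_eff ** 2)                # row-wise mean of squared gradients
    multiplier = lr / (np.sqrt(v_new) + eps)
    if mode == 0:
        correction = 1.0 - multiplier * weight_decay
    else:
        correction = 1.0 - lr * weight_decay
    return correction * w - multiplier * g, v_new  # factored form from the summary

# Sanity check: the factored L2 form equals the direct form w - multiplier * (g + wd * w).
w, g = np.random.randn(8), np.random.randn(8)
w_new, v_new = rowwise_adagrad_step(w, g, v=0.0, lr=0.1, eps=1e-8, weight_decay=0.01, mode=0)
multiplier = 0.1 / (np.sqrt(np.mean((g + 0.01 * w) ** 2)) + 1e-8)
assert np.allclose(w_new, w - multiplier * (g + 0.01 * w))
```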

fbgemm_gpu/codegen/embedding_backward_code_generator.py

Lines changed: 56 additions & 9 deletions
```diff
@@ -367,49 +367,90 @@ def table_info_precomputation(momentum_prefix: str = "momentum1") -> str:
 
 def rowwise_adagrad() -> None:
     split_weight_update = """
-      weight_new.fma_(grad, -multiplier);
+      weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x;
+      weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y;
+      weight_new.acc.z = correction * weight_new.acc.z - multiplier * grad.acc.z;
+      weight_new.acc.w = correction * weight_new.acc.w - multiplier * grad.acc.w;
     """
     split_precomputation = """
     acc_type<cache_t, true> g_local_sum_square = 0.0;
     #pragma unroll kMaxVecsPerThread
     for (int32_t i = 0;
         i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
         ++i) {
-      g_local_sum_square += grad_sum[i].acc.x * grad_sum[i].acc.x +
-          grad_sum[i].acc.y * grad_sum[i].acc.y +
-          grad_sum[i].acc.z * grad_sum[i].acc.z +
-          grad_sum[i].acc.w * grad_sum[i].acc.w;
+      auto gx = grad_sum[i].acc.x;
+      auto gy = grad_sum[i].acc.y;
+      auto gz = grad_sum[i].acc.z;
+      auto gw = grad_sum[i].acc.w;
+      if (weight_decay_mode == 0) {
+        // L2 regularization
+        int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
+        Vec4T<acc_type<cache_t, true>> weight = weight_row_template.load(d, qparams_template);
+        gx += weight_decay * weight.acc.x;
+        gy += weight_decay * weight.acc.y;
+        gz += weight_decay * weight.acc.z;
+        gw += weight_decay * weight.acc.w;
+      }
+      g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
     }
     const acc_type<cache_t, true> g_avg_square =
         warpReduceAllSum<acc_type<cache_t, true>>(g_local_sum_square) / D;
 
     acc_type<cache_t, true> multiplier;
+    acc_type<cache_t, true> correction = 1.0;
     if (threadIdx.x == 0) {
       acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
       momentum1[idx] = new_sum_square_grads;
       multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
+      if (weight_decay_mode == 0) {
+        // L2 regularization
+        correction = 1 - multiplier * weight_decay;
+      } else if (weight_decay_mode == 1){
+        // Decoupled weight decay
+        correction = 1 - learning_rate * weight_decay;
+      }
     }
     multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0);
+    correction = __shfl_sync(0xFFFFFFFF, correction, 0);
     """
     split_weight_update_cpu = """
     acc_type<scalar_t, true> g_local_sum_square = 0.0;
     for (int64_t d = 0; d < D; ++d) {
-      g_local_sum_square += grad_buffer[d] * grad_buffer[d];
+      auto grad = grad_buffer[d];
+      if (weight_decay_mode == 0) {
+        // L2 regularization
+        grad += weight_decay * host_weights_data[embedding_begin + d];
+      }
+      g_local_sum_square += grad * grad;
     }
     auto g_avg_square = g_local_sum_square / D;
     acc_type<scalar_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square;
     momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
     acc_type<scalar_t, true> multiplier;
     multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
+    acc_type<scalar_t, true> correction = 1.0;
+    if (weight_decay_mode == 0) {
+      // L2 regularization
+      correction = 1 - multiplier * weight_decay;
+    } else if (weight_decay_mode == 1) {
+      // Decoupled weight decay
+      correction = 1 - learning_rate * weight_decay;
+    }
     for (int64_t d = 0; d < D; ++d) {
-      host_weights_data[embedding_begin + d] -= grad_buffer[d] * multiplier;
+      host_weights_data[embedding_begin + d] = correction * host_weights_data[embedding_begin + d] - grad_buffer[d] * multiplier;
     }
     """
 
     generate(
         optimizer="rowwise_adagrad",
         args=make_args(
-            [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")]
+            [
+                (TENSOR, "momentum1"),
+                (FLOAT, "eps"),
+                (FLOAT, "learning_rate"),
+                (FLOAT, "weight_decay"),
+                (INT, "weight_decay_mode"),
+            ]
         ),
         split_precomputation=split_precomputation,
         split_weight_update=split_weight_update,
@@ -425,7 +466,13 @@ def rowwise_adagrad() -> None:
     generate(
         optimizer="approx_rowwise_adagrad",
         args=make_args(
-            [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")]
+            [
+                (TENSOR, "momentum1"),
+                (FLOAT, "eps"),
+                (FLOAT, "learning_rate"),
+                (FLOAT, "weight_decay"),
+                (INT, "weight_decay_mode"),
+            ]
         ),
         split_precomputation=split_precomputation,
         split_weight_update=approx_split_weight_update,
```
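
One consequence of the `correction` factor introduced above, shown with a self-contained numeric sketch (illustrative values only): in L2 mode the decay term is scaled by the adaptive `multiplier`, in decoupled mode by the raw `lr`, so a row with a large accumulated second moment decays noticeably less under L2.

```python
import numpy as np

# Same row, same gradient, large accumulated second moment v; only the mode differs.
w, g, v = np.full(4, 1.0), np.full(4, 0.5), 100.0
lr, eps, wd = 0.1, 1e-8, 0.1

# L2: decay rides the adaptive multiplier.
mult_l2 = lr / (np.sqrt(v + np.mean((g + wd * w) ** 2)) + eps)
w_l2 = (1 - mult_l2 * wd) * w - mult_l2 * g

# Decoupled: decay rides the raw learning rate.
mult_dec = lr / (np.sqrt(v + np.mean(g ** 2)) + eps)
w_dec = (1 - lr * wd) * w - mult_dec * g

print(w_l2[0], w_dec[0])  # ~0.994 vs ~0.985: the "hot" row decays less under L2
```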

fbgemm_gpu/codegen/embedding_backward_split_template.cu

Lines changed: 6 additions & 3 deletions
```diff
@@ -292,8 +292,6 @@ split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_
     }
     {% endfor %}
 
-    {{ split_precomputation }}
-
     struct SharedMemory<Vec4T<acc_type<cache_t, true>>> weight_update_buffer;
     Vec4T<acc_type<cache_t, true>>* shared_weight_update_row = weight_update_buffer.getPointer();
 
@@ -315,6 +313,9 @@ split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_
     if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
       qparams_template = weight_row_template.load_qparams();
     }
+
+    {{ split_precomputation }}
+
     float2 qparams_new;
     #pragma unroll kMaxVecsPerThread
     for (int32_t i = 0;
@@ -506,7 +507,6 @@ split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row
     }
     {% endfor %}
 
-    {{ split_precomputation }}
     struct SharedMemory<Vec4T<acc_type<cache_t, true>>> weight_update_buffer;
     Vec4T<acc_type<cache_t, true>>* shared_weight_update_row = weight_update_buffer.getPointer();
     auto weight_row_template = WeightRow<emb_t, cache_t, acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
@@ -526,6 +526,9 @@ split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row
     if (std::is_same<emb_t, uint8_t>::value && !cache_weights){
       qparams_template = weight_row_template.load_qparams();
     }
+
+    {{ split_precomputation }}
+
     float2 qparams_new;
     #pragma unroll kMaxVecsPerThread
     for (int32_t i = 0;
```

Net effect: in both the cta_per_row and warp_per_row kernels, `{{ split_precomputation }}` now runs after `weight_row_template` and `qparams_template` are initialized, which the new rowwise adagrad precomputation requires because it loads weights via `weight_row_template.load(d, qparams_template)`.

fbgemm_gpu/codegen/lookup_args.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -39,6 +39,7 @@ class OptimizerArgs(NamedTuple):
     beta1: float
     beta2: float
     weight_decay: float
+    weight_decay_mode: int
     eta: float
     momentum: float
 
```
fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template

Lines changed: 6 additions & 0 deletions
```diff
@@ -71,6 +71,9 @@ def invoke(
         {% if "weight_decay" in args.split_function_arg_names %}
         weight_decay=optimizer_args.weight_decay,
         {% endif %}
+        {% if "weight_decay_mode" in args.split_function_arg_names %}
+        weight_decay_mode=optimizer_args.weight_decay_mode,
+        {% endif %}
         {% if "eta" in args.split_function_arg_names %}
         eta=optimizer_args.eta,
         {% endif %}
@@ -135,6 +138,9 @@ def invoke(
         {% if "weight_decay" in args.split_function_arg_names %}
         weight_decay=optimizer_args.weight_decay,
         {% endif %}
+        {% if "weight_decay_mode" in args.split_function_arg_names %}
+        weight_decay_mode=optimizer_args.weight_decay_mode,
+        {% endif %}
         {% if "eta" in args.split_function_arg_names %}
         eta=optimizer_args.eta,
         {% endif %}
```
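
The two files above are pure plumbing: `weight_decay_mode` is carried as an `int` on `OptimizerArgs` and forwarded by the invoker template only when the optimizer declares that argument. A condensed, hypothetical sketch of that flow (the real `OptimizerArgs` has many more fields):

```python
from typing import Any, Dict, NamedTuple, Set

class OptimizerArgs(NamedTuple):      # condensed, illustrative subset
    eps: float
    learning_rate: float
    weight_decay: float
    weight_decay_mode: int            # 0 = L2, 1 = decoupled

def invoke(optimizer_args: OptimizerArgs, arg_names: Set[str]) -> Dict[str, Any]:
    # Mirrors the Jinja template: forward a kwarg only if the optimizer declares it.
    kwargs = {}
    for name in ("eps", "learning_rate", "weight_decay", "weight_decay_mode"):
        if name in arg_names:
            kwargs[name] = getattr(optimizer_args, name)
    return kwargs

args = OptimizerArgs(eps=1e-8, learning_rate=0.01, weight_decay=0.01, weight_decay_mode=1)
print(invoke(args, {"eps", "learning_rate", "weight_decay", "weight_decay_mode"}))
```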

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -65,6 +65,11 @@ class BoundsCheckMode(enum.IntEnum):
     NONE = 3
 
 
+class WeightDecayMode(enum.IntEnum):
+    L2 = 0
+    DECOUPLE = 1
+
+
 RecordCacheMetrics: NamedTuple = NamedTuple(
     "RecordCacheMetrics",
     [("record_cache_miss_counter", bool), ("record_tablewise_cache_miss", bool)],
@@ -201,7 +206,8 @@ def __init__( # noqa C901
         learning_rate: float = 0.01,
         eps: float = 1.0e-8,  # used by Adagrad, LAMB, and Adam
         momentum: float = 0.9,  # used by LARS-SGD
-        weight_decay: float = 0.0,  # used by LARS-SGD, LAMB, and ADAM
+        weight_decay: float = 0.0,  # used by LARS-SGD, LAMB, Adagrad, and ADAM
+        weight_decay_mode: WeightDecayMode = WeightDecayMode.L2,  # used by Adagrad
         eta: float = 0.001,  # used by LARS-SGD,
         beta1: float = 0.9,  # used by LAMB and ADAM
         beta2: float = 0.999,  # used by LAMB and ADAM
@@ -357,6 +363,7 @@ def __init__( # noqa C901
             beta1=beta1,
             beta2=beta2,
             weight_decay=weight_decay,
+            weight_decay_mode=weight_decay_mode.value,
             eta=eta,
             momentum=momentum,
         )
@@ -493,7 +500,7 @@ def __init__( # noqa C901
             dtype=cache_embedding_dtype,
         )
 
-        logging.debug(
+        logging.info(
             f"Using fused {optimizer} with optimizer_args={self.optimizer_args}"
         )
 
```
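
From the caller's side, the mode is selected through the `weight_decay_mode` constructor argument added above. A sketch of such a call follows; the optimizer arguments come from this diff, while the `embedding_specs` layout and enum names are assumed from the same module and may differ across fbgemm_gpu versions:

```python
# Sketch: selecting decoupled weight decay on the fused rowwise-adagrad optimizer.
from fbgemm_gpu.split_table_batched_embeddings_ops import (
    ComputeDevice,
    EmbeddingLocation,
    OptimType,
    SplitTableBatchedEmbeddingBagsCodegen,
    WeightDecayMode,
)

emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    # (num_embeddings, embedding_dim, location, compute device) -- assumed layout
    embedding_specs=[(10000, 128, EmbeddingLocation.DEVICE, ComputeDevice.CUDA)],
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    learning_rate=0.01,
    eps=1.0e-8,
    weight_decay=0.01,
    weight_decay_mode=WeightDecayMode.DECOUPLE,  # default is WeightDecayMode.L2
)
```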

fbgemm_gpu/test/split_table_batched_embeddings_test.py

Lines changed: 31 additions & 12 deletions
```diff
@@ -16,10 +16,11 @@
 import numpy as np
 import torch
 from fbgemm_gpu.split_table_batched_embeddings_ops import (
+    BoundsCheckMode,
     OptimType,
-    SparseType,
     RecordCacheMetrics,
-    BoundsCheckMode,
+    SparseType,
+    WeightDecayMode,
 )
 from hypothesis import HealthCheck, Verbosity, assume, given, settings
 from torch import Tensor
@@ -1409,6 +1410,7 @@ def execute_backward_optimizers_( # noqa C901
         long_segments: bool,
         pooling_mode: split_table_batched_embeddings_ops.PoolingMode,
         use_cpu: bool,
+        weight_decay_mode: WeightDecayMode = WeightDecayMode.L2,
     ) -> None:
         # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
         assume(not use_cpu or T * B * L * D <= 2048)
@@ -1534,6 +1536,10 @@ def execute_backward_optimizers_( # noqa C901
         if optimizer in (OptimType.EXACT_ROWWISE_ADAGRAD, OptimType.EXACT_ADAGRAD):
             optimizer_kwargs["eps"] = eps
 
+        if optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
+            optimizer_kwargs["weight_decay"] = weight_decay
+            optimizer_kwargs["weight_decay_mode"] = weight_decay_mode
+
         if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
             optimizer_kwargs["eps"] = eps
             optimizer_kwargs["beta1"] = beta1
@@ -1591,25 +1597,30 @@ def execute_backward_optimizers_( # noqa C901
                 # to_dense in GPU is non-deterministic due to atmomics used in
                 # coalescing and floating point non-associativity.
                 dense_cpu_grad = bs[t].weight.grad.cpu().to_dense()
+                if rowwise and not use_cpu and weight_decay_mode == WeightDecayMode.L2:
+                    # NOTE: CPU code path (https://fburl.com/diffusion/rte4cu6c) is not executed in unit test.
+                    dense_cpu_grad += weight_decay * bs[t].weight.cpu()
                 m1_ref = (
                     dense_cpu_grad.pow(2)
                     if not rowwise
                     else dense_cpu_grad.pow(2).mean(dim=1)
                 )
                 torch.testing.assert_allclose(
-                    m1.float().cpu(), m1_ref.float(), atol=1.0e-4, rtol=1.0e-4
+                    m1.float().index_select(dim=0, index=x[t].view(-1)).cpu(),
+                    m1_ref.float().index_select(dim=0, index=x[t].view(-1).cpu()),
+                    atol=1.0e-4,
+                    rtol=1.0e-4
                 )
                 weights_new = split_weights[t]
-                weights_ref = bs[t].weight.cpu() - lr * dense_cpu_grad / (
-                    torch.sqrt(
-                        m1_ref if not rowwise else m1_ref.view(m1_ref.numel(), 1)
-                    )
-                    + eps
-                )
+                denom = torch.sqrt(m1_ref if not rowwise else m1_ref.view(m1_ref.numel(), 1)) + eps
+                if rowwise and not use_cpu and weight_decay_mode == WeightDecayMode.DECOUPLE:
+                    weights_ref = bs[t].weight.cpu() - lr * (dense_cpu_grad / denom + weight_decay * bs[t].weight.cpu())
+                else:
+                    weights_ref = bs[t].weight.cpu() - lr * dense_cpu_grad / denom
                 # TODO: why is tolerance off here?
                 torch.testing.assert_allclose(
-                    weights_new.float().cpu(),
-                    weights_ref.float(),
+                    weights_new.index_select(dim=0, index=x[t].view(-1)).cpu(),
+                    weights_ref.index_select(dim=0, index=x[t].view(-1).cpu()),
                     atol=1.0e-2,
                     rtol=1.0e-2,
                 )
@@ -1793,6 +1804,12 @@ def test_backward_optimizers_adam( # noqa C901
             ]
         ),
         use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True),
+        weight_decay_mode=st.sampled_from(
+            [
+                WeightDecayMode.L2,
+                WeightDecayMode.DECOUPLE,
+            ]
+        ),
     )
     @settings(
         verbosity=Verbosity.verbose,
@@ -1813,9 +1830,11 @@ def test_backward_optimizers_adagrad( # noqa C901
         long_segments: bool,
         pooling_mode: split_table_batched_embeddings_ops.PoolingMode,
         use_cpu: bool,
+        weight_decay_mode: WeightDecayMode,
     ) -> None:
         self.execute_backward_optimizers_(T, D, B, log_E, L, weighted,
-                                          mixed, optimizer, long_segments, pooling_mode, use_cpu)
+                                          mixed, optimizer, long_segments, pooling_mode, use_cpu,
+                                          weight_decay_mode)
 
     @given(
         T=st.integers(min_value=1, max_value=5),
```
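
A note on why the assertions above switched to `index_select`: the exact fused optimizer only updates rows that actually appear in a batch, while a dense reference with weight decay would decay every row, so only the visited rows can be compared meaningfully. A small self-contained illustration (hypothetical values):

```python
import torch

# Hypothetical shapes/values, just to show the mismatch on untouched rows.
D, lr, eps, weight_decay = 4, 0.1, 1.0e-8, 0.01
w_row = torch.ones(D)                      # a row that never appears in the batch
grad_row = torch.zeros(D)                  # so its true gradient is zero

# Fused exact optimizer: the row is never visited, so its weights stay put.
w_fused = w_row.clone()

# Dense L2 reference: decay is folded into the gradient for *every* row.
ref_grad = grad_row + weight_decay * w_row
m1_ref = ref_grad.pow(2).mean()
w_ref = w_row - lr * ref_grad / (torch.sqrt(m1_ref) + eps)

print(torch.allclose(w_fused, w_ref))      # False -> hence index_select in the test
```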
