Support L2 regularization and decoupled weight decay in rowwise adagrad #718

Open
yuxihu wants to merge 1 commit into pytorch:main from yuxihu:export-D31285351
@yuxihu yuxihu commented Oct 5, 2021

Summary:
Add two kinds of weight decay in rowwise adagrad:

L2 regularization:

g' = g + weight_decay * w
multiplier = lr / (sqrt(v) + eps)
w = w - lr * g' / (sqrt(v) + eps)
    = w - lr * g / (sqrt(v) + eps) - lr * weight_decay * w / (sqrt(v) + eps)
    = (1 - multiplier * weight_decay) * w - multiplier * g

Decoupled weight decay:

multiplier = lr / (sqrt(v) + eps)
w = w - lr * (g / (sqrt(v) + eps) + weight_decay * w)
    = w - lr * g / (sqrt(v) + eps) - lr * weight_decay * w
    = (1 - lr * weight_decay) * w - multiplier * g
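
As a sanity check on the algebra above, here is a minimal Python sketch of the two per-row weight updates. It assumes `v` is the row's accumulated squared-gradient statistic after it has already been updated for the current step; how `v` itself is accumulated is out of scope here. The function names are hypothetical and are not FBGEMM APIs; each mode is written in both its fused (last line of the derivation) and unfused (first line) form so the two can be compared.

```python
import math


def l2_update(w, g, v, lr, eps, weight_decay):
    """Fused L2 regularization step for one row (hypothetical helper).

    Implements w = (1 - multiplier * weight_decay) * w - multiplier * g.
    """
    multiplier = lr / (math.sqrt(v) + eps)
    scale = 1.0 - multiplier * weight_decay
    return [scale * wi - multiplier * gi for wi, gi in zip(w, g)]


def l2_update_unfused(w, g, v, lr, eps, weight_decay):
    """Same step written as w - lr * g' / (sqrt(v) + eps), with g' = g + weight_decay * w."""
    denom = math.sqrt(v) + eps
    return [wi - lr * (gi + weight_decay * wi) / denom for wi, gi in zip(w, g)]


def decoupled_update(w, g, v, lr, eps, weight_decay):
    """Fused decoupled weight decay step for one row (hypothetical helper).

    Implements w = (1 - lr * weight_decay) * w - multiplier * g.
    """
    multiplier = lr / (math.sqrt(v) + eps)
    scale = 1.0 - lr * weight_decay
    return [scale * wi - multiplier * gi for wi, gi in zip(w, g)]


def decoupled_update_unfused(w, g, v, lr, eps, weight_decay):
    """Same step written as w - lr * (g / (sqrt(v) + eps) + weight_decay * w)."""
    denom = math.sqrt(v) + eps
    return [wi - lr * (gi / denom + weight_decay * wi) for wi, gi in zip(w, g)]
```

The only difference between the two modes is the factor applied to `w`: L2 regularization scales the decay term by the adaptive `multiplier`, while decoupled weight decay scales it by the plain `lr`. The two therefore coincide only when `sqrt(v) + eps` equals 1.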

Reviewed By: choudharydhruv

Differential Revision: D31285351

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D31285351

yuxihu pushed a commit to yuxihu/FBGEMM that referenced this pull request Oct 5, 2021
…ad (pytorch#718)

fbshipit-source-id: 7ee3cf014fff05d837d34ec0c2b67e189272f502

…ad (pytorch#718)

fbshipit-source-id: e361627f8426856021badef0410455e23620f21b