Commit daacc71

[feat] add multiple ranking losses & loss weight adaptive learning & fix a few bugs (alibaba#358)

1 parent 53e703f

File tree

17 files changed: +702 −110 lines

docs/source/train.md

Lines changed: 72 additions & 5 deletions

```diff
@@ -80,17 +80,21 @@ EasyRec supports two ways to configure loss functions: 1) use a single loss function; 2
 
 #### Using a single loss function
 
 | Loss function                              | Description                                                 |
 | ------------------------------------------ | ----------------------------------------------------------- |
 | CLASSIFICATION                             | classification loss: sigmoid_cross_entropy for binary classification, softmax_cross_entropy for multi-class |
 | L2_LOSS                                    | squared loss                                                |
 | SIGMOID_L2_LOSS                            | squared loss computed on the sigmoid of the output          |
 | CROSS_ENTROPY_LOSS                         | log loss (negative log-likelihood)                          |
 | CIRCLE_LOSS                                | used only by the CoMetricLearningI2I model                  |
 | MULTI_SIMILARITY_LOSS                      | used only by the CoMetricLearningI2I model                  |
 | SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING | softmax_cross_entropy with automatic negative sampling, used in binary classification tasks |
+| BINARY_FOCAL_LOSS                          | focal loss with hard-example mining and class balancing     |
 | PAIR_WISE_LOSS                             | rank loss that optimizes global AUC                         |
+| PAIRWISE_FOCAL_LOSS                        | pair-level focal loss with configurable pair grouping       |
+| PAIRWISE_LOGISTIC_LOSS                     | pair-level logistic loss with configurable pair grouping    |
+| JRC_LOSS                                   | binary classification + listwise ranking loss               |
 | F1_REWEIGHTED_LOSS                         | adjusts the relative weight of recall and precision in binary classification; effective against positive/negative class imbalance |
 
 - Note: SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
   supports parameter configuration and has been upgraded to [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317)
```
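To use one of the losses from the table above on its own, the loss is declared in a `losses` block of the model config. A minimal sketch, reusing the field names that appear in this commit's `learn_loss_weight` example (`binary_focal_loss`, `gamma`, `alpha`); treat it as illustrative rather than a complete model config:

```protobuf
losses {
  loss_type: BINARY_FOCAL_LOSS
  binary_focal_loss {
    gamma: 2.0
    alpha: 0.85
  }
}
```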
```diff
@@ -148,9 +152,72 @@ EasyRec supports two ways to configure loss functions: 1) use a single loss function; 2
   - ![f_beta score](../images/other/f_beta_score.svg)
   - f1_beta_square is the square of the beta coefficient in the formula above.
 
+- Parameters of PAIRWISE_FOCAL_LOSS
+  - gamma: exponent of the focal loss, default 2.0
+  - alpha: class-balance parameter for adjusting sample weights; it is recommended to set it from the positive/negative sample ratio, $\frac{\alpha}{1-\alpha}=\frac{\#Neg}{\#Pos}$
+  - session_name: name of the field used to group pairs, e.g. user_id
+  - hinge_margin: when a pair's logit difference is greater than this value, the sample's loss is 0; default 1.0
+  - ohem_ratio: percentage of hard examples; only this fraction of hard examples takes part in the loss, default 1.0
+  - temperature: temperature coefficient; the logits are divided by this value before the loss is computed, default 1.0
+
+- Parameters of PAIRWISE_LOGISTIC_LOSS
+  - session_name: name of the field used to group pairs, e.g. user_id
+  - hinge_margin: when a pair's logit difference is greater than this value, the sample's loss is 0; default 1.0
+  - ohem_ratio: percentage of hard examples; only this fraction of hard examples takes part in the loss, default 1.0
+  - temperature: temperature coefficient; the logits are divided by this value before the loss is computed, default 1.0
+
+- Parameters of PAIR_WISE_LOSS
+  - session_name: name of the field used to group pairs, e.g. user_id
+  - margin: subtracted from a pair's logit difference before the loss is computed, i.e. the logit gap between positive and negative samples must be at least margin; default 0
+  - temperature: temperature coefficient; the logits are divided by this value before the loss is computed, default 1.0
+
+Note: all of the PAIRWISE_*_LOSS variants above build positive/negative sample pairs within a mini-batch; the goal is to make the logit gap between the two samples of each pair as large as possible.
+
+- Parameters of BINARY_FOCAL_LOSS
+  - gamma: exponent of the focal loss, default 2.0
+  - alpha: class-balance parameter for adjusting sample weights; it is recommended to set it from the positive/negative sample ratio, $\frac{\alpha}{1-\alpha}=\frac{\#Neg}{\#Pos}$
+  - ohem_ratio: percentage of hard examples; only this fraction of hard examples takes part in the loss, default 1.0
+  - label_smoothing: label smoothing coefficient
+
+- Parameters of JRC_LOSS
+  - alpha: relative weight of the ranking loss and the calibration loss; when this value is not set, adaptive weight learning is triggered
+  - session_name: name of the field used to group lists, e.g. user_id
+  - Reference paper: [Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model](https://arxiv.org/pdf/2208.06164.pdf)
```
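The pair construction used by the pairwise losses above (within-session pairs, hinge margin, temperature) can be sketched in plain numpy. This is an O(n²) illustration under stated assumptions, not EasyRec's `easy_rec.python.loss.pairwise_loss` implementation; it omits `ohem_ratio` and per-sample weights:

```python
import numpy as np

def pairwise_logistic_loss(labels, logits, session_ids=None,
                           temperature=1.0, hinge_margin=None):
    """Pair-level logistic loss sketch: for every (positive, negative) pair
    inside the same session, penalize log(1 + exp(-(s_pos - s_neg)))."""
    logits = np.asarray(logits, dtype=np.float64) / temperature
    labels = np.asarray(labels)
    if session_ids is None:
        session_ids = np.zeros(len(labels), dtype=np.int64)
    losses = []
    for i in range(len(labels)):          # candidate positive
        for j in range(len(labels)):      # candidate negative
            if labels[i] <= labels[j]:
                continue                  # need label_i > label_j
            if session_ids[i] != session_ids[j]:
                continue                  # pairs are built within a session
            diff = logits[i] - logits[j]
            if hinge_margin is not None and diff > hinge_margin:
                continue                  # pair already separated enough
            losses.append(np.log1p(np.exp(-diff)))
    return float(np.mean(losses)) if losses else 0.0
```

Note how `hinge_margin` zeroes out pairs whose logit gap already exceeds the margin, matching the parameter description above.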
````diff
 
 A complete example of a ranking model that uses multiple loss functions at the same time:
 [cmbf_with_multi_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/cmbf_with_multi_loss.config)
 
+##### Adaptive learning of loss weights
+
+In multi-objective learning tasks, manually assigned static weights for the individual loss functions usually do not achieve the best results. EasyRec supports adaptive learning of the loss-function weights; an example:
+
+```protobuf
+losses {
+  loss_type: CLASSIFICATION
+  learn_loss_weight: true
+}
+losses {
+  loss_type: BINARY_FOCAL_LOSS
+  learn_loss_weight: true
+  binary_focal_loss {
+    gamma: 2.0
+    alpha: 0.85
+  }
+}
+losses {
+  loss_type: PAIRWISE_FOCAL_LOSS
+  learn_loss_weight: true
+  pairwise_focal_loss {
+    session_name: "client_str"
+    hinge_margin: 1.0
+  }
+}
+```
+
+The `learn_loss_weight` parameter controls whether adaptive weight learning is enabled; it is disabled by default. Once enabled, the `weight` parameter no longer takes effect.
+
+Reference paper: "Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics"
 
 ## Training commands
 
 ### Local
````
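The `learn_loss_weight` mechanism follows the uncertainty-weighting scheme of the cited paper, where each task loss L_i gets a learnable log-variance s_i. A minimal sketch of the combined objective (function name hypothetical; EasyRec's actual formulation may differ, e.g. by factors of 0.5):

```python
import math

def uncertainty_weighted_total(losses, log_vars):
    """Kendall et al. style weighting: total = sum_i exp(-s_i) * L_i + s_i.
    exp(-s_i) down-weights tasks with high uncertainty, while the +s_i
    regularizer keeps s_i from growing without bound."""
    return sum(math.exp(-s) * loss + s for loss, s in zip(losses, log_vars))
```

In training, the s_i would be trainable variables optimized jointly with the model parameters, which is what replaces the static per-loss `weight`.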

easy_rec/python/builders/loss_builder.py

Lines changed: 75 additions & 1 deletion

```diff
@@ -4,6 +4,10 @@
 
 import tensorflow as tf
 
+from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
+from easy_rec.python.loss.jrc_loss import jrc_loss
+from easy_rec.python.loss.pairwise_loss import pairwise_focal_loss
+from easy_rec.python.loss.pairwise_loss import pairwise_logistic_loss
 from easy_rec.python.loss.pairwise_loss import pairwise_loss
 from easy_rec.python.protos.loss_pb2 import LossType
 
@@ -20,6 +24,7 @@ def build(loss_type,
           num_class=1,
           loss_param=None,
           **kwargs):
+  loss_name = kwargs.pop('loss_name') if 'loss_name' in kwargs else 'unknown'
   if loss_type == LossType.CLASSIFICATION:
     if num_class == 1:
       return tf.losses.sigmoid_cross_entropy(
@@ -35,8 +40,60 @@ def build(loss_type,
     logging.info('%s is used' % LossType.Name(loss_type))
     return tf.losses.mean_squared_error(
         labels=label, predictions=pred, weights=loss_weight, **kwargs)
+  elif loss_type == LossType.JRC_LOSS:
+    alpha = 0.5 if loss_param is None else loss_param.alpha
+    auto_weight = False if loss_param is None else not loss_param.HasField(
+        'alpha')
+    session = kwargs.get('session_ids', None)
+    return jrc_loss(
+        label, pred, session, alpha, auto_weight=auto_weight, name=loss_name)
   elif loss_type == LossType.PAIR_WISE_LOSS:
-    return pairwise_loss(label, pred)
+    session = kwargs.get('session_ids', None)
+    margin = 0 if loss_param is None else loss_param.margin
+    temp = 1.0 if loss_param is None else loss_param.temperature
+    return pairwise_loss(
+        label,
+        pred,
+        session_ids=session,
+        margin=margin,
+        temperature=temp,
+        weights=loss_weight,
+        name=loss_name)
+  elif loss_type == LossType.PAIRWISE_LOGISTIC_LOSS:
+    session = kwargs.get('session_ids', None)
+    temp = 1.0 if loss_param is None else loss_param.temperature
+    ohem_ratio = 1.0 if loss_param is None else loss_param.ohem_ratio
+    hinge_margin = None
+    if loss_param is not None and loss_param.HasField('hinge_margin'):
+      hinge_margin = loss_param.hinge_margin
+    return pairwise_logistic_loss(
+        label,
+        pred,
+        session_ids=session,
+        temperature=temp,
+        hinge_margin=hinge_margin,
+        ohem_ratio=ohem_ratio,
+        weights=loss_weight,
+        name=loss_name)
+  elif loss_type == LossType.PAIRWISE_FOCAL_LOSS:
+    session = kwargs.get('session_ids', None)
+    if loss_param is None:
+      return pairwise_focal_loss(
+          label, pred, session_ids=session, weights=loss_weight, name=loss_name)
+    hinge_margin = None
+    if loss_param.HasField('hinge_margin'):
+      hinge_margin = loss_param.hinge_margin
+    return pairwise_focal_loss(
+        label,
+        pred,
+        session_ids=session,
+        gamma=loss_param.gamma,
+        alpha=loss_param.alpha if loss_param.HasField('alpha') else None,
+        hinge_margin=hinge_margin,
+        ohem_ratio=loss_param.ohem_ratio,
+        temperature=loss_param.temperature,
+        weights=loss_weight,
+        name=loss_name)
   elif loss_type == LossType.F1_REWEIGHTED_LOSS:
     f1_beta_square = 1.0 if loss_param is None else loss_param.f1_beta_square
     label_smoothing = 0 if loss_param is None else loss_param.label_smoothing
@@ -46,6 +103,23 @@ def build(loss_type,
         f1_beta_square,
         weights=loss_weight,
         label_smoothing=label_smoothing)
+  elif loss_type == LossType.BINARY_FOCAL_LOSS:
+    if loss_param is None:
+      return sigmoid_focal_loss_with_logits(
+          label, pred, sample_weights=loss_weight, name=loss_name)
+    gamma = loss_param.gamma
+    alpha = None
+    if loss_param.HasField('alpha'):
+      alpha = loss_param.alpha
+    return sigmoid_focal_loss_with_logits(
+        label,
+        pred,
+        gamma=gamma,
+        alpha=alpha,
+        ohem_ratio=loss_param.ohem_ratio,
+        sample_weights=loss_weight,
+        label_smoothing=loss_param.label_smoothing,
+        name=loss_name)
   else:
     raise ValueError('unsupported loss type: %s' % LossType.Name(loss_type))
```
51125

easy_rec/python/compat/weight_decay_optimizers.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -411,7 +411,7 @@ def __init__(self,
 
 
 try:
-  from tensorflow.python.training import AdamAsyncOptimizer
+  from tensorflow.train import AdamAsyncOptimizer
 
   @tf_export('contrib.opt.AdamAsyncWOptimizer')
   class AdamAsyncWOptimizer(DecoupledWeightDecayExtension, AdamAsyncOptimizer):
@@ -472,4 +472,5 @@ def __init__(self,
         use_locking=use_locking,
         name=name)
 except ImportError:
+  print('import AdamAsyncOptimizer failed when loading AdamAsyncWOptimizer')
   pass
```
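The file above mixes `DecoupledWeightDecayExtension` into `AdamAsyncOptimizer`. The idea of decoupled weight decay (the AdamW family) is that the decay term is applied directly to the weights rather than folded into the gradient, so adaptive-gradient rescaling does not distort it. A minimal SGDW-style sketch (hypothetical helper, not this file's API):

```python
def sgdw_step(w, grad, lr=0.1, weight_decay=0.01):
    """One decoupled-weight-decay step: the decay is a separate shrinkage
    term on w, not added to grad before any adaptive rescaling."""
    return w - lr * grad - lr * weight_decay * w
```

The same separation is what `AdamAsyncWOptimizer` applies on top of the Adam-style update.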
