Skip to content

Commit 3f0d697

Browse files
authored
[bug fix]: fix handling of multiple losses in rank model (alibaba#335)
* [bug fix]: fix handling of multiple losses in rank model & fix code style check problems
1 parent 6c8591b commit 3f0d697

File tree

2 files changed

+81
-36
lines changed

2 files changed

+81
-36
lines changed

easy_rec/python/model/multi_task_model.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,35 @@ def _init_towers(self, task_tower_configs):
5151
def _add_to_prediction_dict(self, output):
5252
for task_tower_cfg in self._task_towers:
5353
tower_name = task_tower_cfg.tower_name
54-
self._prediction_dict.update(
55-
self._output_to_prediction_impl(
56-
output[tower_name],
57-
loss_type=task_tower_cfg.loss_type,
58-
num_class=task_tower_cfg.num_class,
59-
suffix='_%s' % tower_name))
54+
if len(task_tower_cfg.losses) == 0:
55+
self._prediction_dict.update(
56+
self._output_to_prediction_impl(
57+
output[tower_name],
58+
loss_type=task_tower_cfg.loss_type,
59+
num_class=task_tower_cfg.num_class,
60+
suffix='_%s' % tower_name))
61+
else:
62+
for loss in task_tower_cfg.losses:
63+
self._prediction_dict.update(
64+
self._output_to_prediction_impl(
65+
output[tower_name],
66+
loss_type=loss.loss_type,
67+
num_class=task_tower_cfg.num_class,
68+
suffix='_%s' % tower_name))
6069

6170
def build_metric_graph(self, eval_config):
6271
"""Build metric graph for multi task model."""
6372
metric_dict = {}
6473
for task_tower_cfg in self._task_towers:
6574
tower_name = task_tower_cfg.tower_name
6675
for metric in task_tower_cfg.metrics_set:
76+
loss_types = {task_tower_cfg.loss_type}
77+
if len(task_tower_cfg.losses) > 0:
78+
loss_types = {loss.loss_type for loss in task_tower_cfg.losses}
6779
metric_dict.update(
6880
self._build_metric_impl(
6981
metric,
70-
loss_type=task_tower_cfg.loss_type,
82+
loss_type=loss_types,
7183
label_name=self._label_name_dict[tower_name],
7284
num_class=task_tower_cfg.num_class,
7385
suffix='_%s' % tower_name))
@@ -123,9 +135,17 @@ def get_outputs(self):
123135
outputs = []
124136
for task_tower_cfg in self._task_towers:
125137
tower_name = task_tower_cfg.tower_name
126-
outputs.extend(
127-
self._get_outputs_impl(
128-
task_tower_cfg.loss_type,
129-
task_tower_cfg.num_class,
130-
suffix='_%s' % tower_name))
131-
return outputs
138+
if len(task_tower_cfg.losses) == 0:
139+
outputs.extend(
140+
self._get_outputs_impl(
141+
task_tower_cfg.loss_type,
142+
task_tower_cfg.num_class,
143+
suffix='_%s' % tower_name))
144+
else:
145+
for loss in task_tower_cfg.losses:
146+
outputs.extend(
147+
self._get_outputs_impl(
148+
loss.loss_type,
149+
task_tower_cfg.num_class,
150+
suffix='_%s' % tower_name))
151+
return list(set(outputs))

easy_rec/python/model/rank_model.py

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,15 @@ def _output_to_prediction_impl(self,
6666
return prediction_dict
6767

6868
def _add_to_prediction_dict(self, output):
69-
prediction_dict = self._output_to_prediction_impl(
70-
output, loss_type=self._loss_type, num_class=self._num_class)
71-
self._prediction_dict.update(prediction_dict)
69+
if len(self._losses) == 0:
70+
prediction_dict = self._output_to_prediction_impl(
71+
output, loss_type=self._loss_type, num_class=self._num_class)
72+
self._prediction_dict.update(prediction_dict)
73+
else:
74+
for loss in self._losses:
75+
prediction_dict = self._output_to_prediction_impl(
76+
output, loss_type=loss.loss_type, num_class=self._num_class)
77+
self._prediction_dict.update(prediction_dict)
7278

7379
def build_rtp_output_dict(self):
7480
"""Forward tensor as `rank_predict`, which is a special node for RTP."""
@@ -85,15 +91,22 @@ def build_rtp_output_dict(self):
8591
rank_predict = op.outputs[0]
8692
except KeyError:
8793
forwarded = None
88-
if self._loss_type == LossType.CLASSIFICATION:
94+
loss_types = {self._loss_type}
95+
if len(self._losses) > 0:
96+
loss_types = {loss.loss_type for loss in self._losses}
97+
binary_loss_set = {
98+
LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
99+
LossType.PAIR_WISE_LOSS
100+
}
101+
if loss_types & binary_loss_set:
89102
if 'probs' in self._prediction_dict:
90103
forwarded = self._prediction_dict['probs']
91104
else:
92105
raise ValueError(
93106
'failed to build RTP rank_predict output: classification model ' +
94107
"expect 'probs' prediction, which is not found. Please check if" +
95108
' build_predict_graph() is called.')
96-
elif self._loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
109+
elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
97110
if 'y' in self._prediction_dict:
98111
forwarded = self._prediction_dict['y']
99112
else:
@@ -105,7 +118,7 @@ def build_rtp_output_dict(self):
105118
else:
106119
logging.warning(
107120
'failed to build RTP rank_predict: unsupported loss type {}'.format(
108-
self._loss_type))
121+
loss_types))
109122
if forwarded is not None:
110123
rank_predict = tf.identity(forwarded, name='rank_predict')
111124
if rank_predict is not None:
@@ -183,6 +196,8 @@ def _build_metric_impl(self,
183196
label_name,
184197
num_class=1,
185198
suffix=''):
199+
if not isinstance(loss_type, set):
200+
loss_type = {loss_type}
186201
from easy_rec.python.core.easyrec_metrics import metrics_tf
187202
from easy_rec.python.core import metrics as metrics_lib
188203
binary_loss_set = {
@@ -191,8 +206,7 @@ def _build_metric_impl(self,
191206
}
192207
metric_dict = {}
193208
if metric.WhichOneof('metric') == 'auc':
194-
assert loss_type in binary_loss_set
195-
209+
assert loss_type & binary_loss_set
196210
if num_class == 1:
197211
label = tf.to_int64(self._labels[label_name])
198212
metric_dict['auc' + suffix] = metrics_tf.auc(
@@ -208,7 +222,7 @@ def _build_metric_impl(self,
208222
else:
209223
raise ValueError('Wrong class number')
210224
elif metric.WhichOneof('metric') == 'gauc':
211-
assert loss_type in binary_loss_set
225+
assert loss_type & binary_loss_set
212226
if num_class == 1:
213227
label = tf.to_int64(self._labels[label_name])
214228
uids = self._feature_dict[metric.gauc.uid_field]
@@ -231,7 +245,7 @@ def _build_metric_impl(self,
231245
else:
232246
raise ValueError('Wrong class number')
233247
elif metric.WhichOneof('metric') == 'session_auc':
234-
assert loss_type in binary_loss_set
248+
assert loss_type & binary_loss_set
235249
if num_class == 1:
236250
label = tf.to_int64(self._labels[label_name])
237251
metric_dict['session_auc' + suffix] = metrics_lib.session_auc(
@@ -249,7 +263,7 @@ def _build_metric_impl(self,
249263
else:
250264
raise ValueError('Wrong class number')
251265
elif metric.WhichOneof('metric') == 'max_f1':
252-
assert loss_type in binary_loss_set
266+
assert loss_type & binary_loss_set
253267
if num_class == 1:
254268
label = tf.to_int64(self._labels[label_name])
255269
metric_dict['max_f1' + suffix] = metrics_lib.max_f1(
@@ -261,50 +275,50 @@ def _build_metric_impl(self,
261275
else:
262276
raise ValueError('Wrong class number')
263277
elif metric.WhichOneof('metric') == 'recall_at_topk':
264-
assert loss_type in binary_loss_set
278+
assert loss_type & binary_loss_set
265279
assert num_class > 1
266280
label = tf.to_int64(self._labels[label_name])
267281
metric_dict['recall_at_topk' + suffix] = metrics_tf.recall_at_k(
268282
label, self._prediction_dict['logits' + suffix],
269283
metric.recall_at_topk.topk)
270284
elif metric.WhichOneof('metric') == 'mean_absolute_error':
271285
label = tf.to_float(self._labels[label_name])
272-
if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
286+
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
273287
metric_dict['mean_absolute_error' +
274288
suffix] = metrics_tf.mean_absolute_error(
275289
label, self._prediction_dict['y' + suffix])
276-
elif loss_type == LossType.CLASSIFICATION and num_class == 1:
290+
elif loss_type & {LossType.CLASSIFICATION} and num_class == 1:
277291
metric_dict['mean_absolute_error' +
278292
suffix] = metrics_tf.mean_absolute_error(
279293
label, self._prediction_dict['probs' + suffix])
280294
else:
281295
assert False, 'mean_absolute_error is not supported for this model'
282296
elif metric.WhichOneof('metric') == 'mean_squared_error':
283297
label = tf.to_float(self._labels[label_name])
284-
if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
298+
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
285299
metric_dict['mean_squared_error' +
286300
suffix] = metrics_tf.mean_squared_error(
287301
label, self._prediction_dict['y' + suffix])
288-
elif num_class == 1 and loss_type in binary_loss_set:
302+
elif num_class == 1 and loss_type & binary_loss_set:
289303
metric_dict['mean_squared_error' +
290304
suffix] = metrics_tf.mean_squared_error(
291305
label, self._prediction_dict['probs' + suffix])
292306
else:
293307
assert False, 'mean_squared_error is not supported for this model'
294308
elif metric.WhichOneof('metric') == 'root_mean_squared_error':
295309
label = tf.to_float(self._labels[label_name])
296-
if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
310+
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
297311
metric_dict['root_mean_squared_error' +
298312
suffix] = metrics_tf.root_mean_squared_error(
299313
label, self._prediction_dict['y' + suffix])
300-
elif loss_type == LossType.CLASSIFICATION and num_class == 1:
314+
elif loss_type & {LossType.CLASSIFICATION} and num_class == 1:
301315
metric_dict['root_mean_squared_error' +
302316
suffix] = metrics_tf.root_mean_squared_error(
303317
label, self._prediction_dict['probs' + suffix])
304318
else:
305319
assert False, 'root_mean_squared_error is not supported for this model'
306320
elif metric.WhichOneof('metric') == 'accuracy':
307-
assert loss_type == LossType.CLASSIFICATION
321+
assert loss_type & {LossType.CLASSIFICATION}
308322
assert num_class > 1
309323
label = tf.to_int64(self._labels[label_name])
310324
metric_dict['accuracy' + suffix] = metrics_tf.accuracy(
@@ -313,20 +327,24 @@ def _build_metric_impl(self,
313327

314328
def build_metric_graph(self, eval_config):
315329
metric_dict = {}
330+
loss_types = {self._loss_type}
331+
if len(self._losses) > 0:
332+
loss_types = {loss.loss_type for loss in self._losses}
316333
for metric in eval_config.metrics_set:
317334
metric_dict.update(
318335
self._build_metric_impl(
319336
metric,
320-
loss_type=self._loss_type,
337+
loss_type=loss_types,
321338
label_name=self._label_name,
322339
num_class=self._num_class))
323340
return metric_dict
324341

325342
def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
326-
if loss_type in [
343+
binary_loss_set = {
327344
LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
328345
LossType.PAIR_WISE_LOSS
329-
]:
346+
}
347+
if loss_type in binary_loss_set:
330348
if num_class == 1:
331349
return ['probs' + suffix, 'logits' + suffix]
332350
else:
@@ -340,4 +358,11 @@ def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
340358
raise ValueError('invalid loss type: %s' % LossType.Name(loss_type))
341359

342360
def get_outputs(self):
343-
return self._get_outputs_impl(self._loss_type, self._num_class)
361+
if len(self._losses) == 0:
362+
return self._get_outputs_impl(self._loss_type, self._num_class)
363+
364+
all_outputs = []
365+
for loss in self._losses:
366+
outputs = self._get_outputs_impl(loss.loss_type, self._num_class)
367+
all_outputs.extend(outputs)
368+
return list(set(all_outputs))

0 commit comments

Comments
 (0)