@@ -66,9 +66,15 @@ def _output_to_prediction_impl(self,
6666 return prediction_dict
6767
6868 def _add_to_prediction_dict (self , output ):
69- prediction_dict = self ._output_to_prediction_impl (
70- output , loss_type = self ._loss_type , num_class = self ._num_class )
71- self ._prediction_dict .update (prediction_dict )
69+ if len (self ._losses ) == 0 :
70+ prediction_dict = self ._output_to_prediction_impl (
71+ output , loss_type = self ._loss_type , num_class = self ._num_class )
72+ self ._prediction_dict .update (prediction_dict )
73+ else :
74+ for loss in self ._losses :
75+ prediction_dict = self ._output_to_prediction_impl (
76+ output , loss_type = loss .loss_type , num_class = self ._num_class )
77+ self ._prediction_dict .update (prediction_dict )
7278
7379 def build_rtp_output_dict (self ):
7480 """Forward tensor as `rank_predict`, which is a special node for RTP."""
@@ -85,15 +91,22 @@ def build_rtp_output_dict(self):
8591 rank_predict = op .outputs [0 ]
8692 except KeyError :
8793 forwarded = None
88- if self ._loss_type == LossType .CLASSIFICATION :
94+ loss_types = {self ._loss_type }
95+ if len (self ._losses ) > 0 :
96+ loss_types = {loss .loss_type for loss in self ._losses }
97+ binary_loss_set = {
98+ LossType .CLASSIFICATION , LossType .F1_REWEIGHTED_LOSS ,
99+ LossType .PAIR_WISE_LOSS
100+ }
101+ if loss_types & binary_loss_set :
89102 if 'probs' in self ._prediction_dict :
90103 forwarded = self ._prediction_dict ['probs' ]
91104 else :
92105 raise ValueError (
93106 'failed to build RTP rank_predict output: classification model ' +
94107 "expect 'probs' prediction, which is not found. Please check if" +
95108 ' build_predict_graph() is called.' )
96- elif self . _loss_type in [ LossType .L2_LOSS , LossType .SIGMOID_L2_LOSS ] :
109+ elif loss_types & { LossType .L2_LOSS , LossType .SIGMOID_L2_LOSS } :
97110 if 'y' in self ._prediction_dict :
98111 forwarded = self ._prediction_dict ['y' ]
99112 else :
@@ -105,7 +118,7 @@ def build_rtp_output_dict(self):
105118 else :
106119 logging .warning (
107120 'failed to build RTP rank_predict: unsupported loss type {}' .foramt (
108- self . _loss_type ))
121+ loss_types ))
109122 if forwarded is not None :
110123 rank_predict = tf .identity (forwarded , name = 'rank_predict' )
111124 if rank_predict is not None :
@@ -183,6 +196,8 @@ def _build_metric_impl(self,
183196 label_name ,
184197 num_class = 1 ,
185198 suffix = '' ):
199+ if not isinstance (loss_type , set ):
200+ loss_type = {loss_type }
186201 from easy_rec .python .core .easyrec_metrics import metrics_tf
187202 from easy_rec .python .core import metrics as metrics_lib
188203 binary_loss_set = {
@@ -191,8 +206,7 @@ def _build_metric_impl(self,
191206 }
192207 metric_dict = {}
193208 if metric .WhichOneof ('metric' ) == 'auc' :
194- assert loss_type in binary_loss_set
195-
209+ assert loss_type & binary_loss_set
196210 if num_class == 1 :
197211 label = tf .to_int64 (self ._labels [label_name ])
198212 metric_dict ['auc' + suffix ] = metrics_tf .auc (
@@ -208,7 +222,7 @@ def _build_metric_impl(self,
208222 else :
209223 raise ValueError ('Wrong class number' )
210224 elif metric .WhichOneof ('metric' ) == 'gauc' :
211- assert loss_type in binary_loss_set
225+ assert loss_type & binary_loss_set
212226 if num_class == 1 :
213227 label = tf .to_int64 (self ._labels [label_name ])
214228 uids = self ._feature_dict [metric .gauc .uid_field ]
@@ -231,7 +245,7 @@ def _build_metric_impl(self,
231245 else :
232246 raise ValueError ('Wrong class number' )
233247 elif metric .WhichOneof ('metric' ) == 'session_auc' :
234- assert loss_type in binary_loss_set
248+ assert loss_type & binary_loss_set
235249 if num_class == 1 :
236250 label = tf .to_int64 (self ._labels [label_name ])
237251 metric_dict ['session_auc' + suffix ] = metrics_lib .session_auc (
@@ -249,7 +263,7 @@ def _build_metric_impl(self,
249263 else :
250264 raise ValueError ('Wrong class number' )
251265 elif metric .WhichOneof ('metric' ) == 'max_f1' :
252- assert loss_type in binary_loss_set
266+ assert loss_type & binary_loss_set
253267 if num_class == 1 :
254268 label = tf .to_int64 (self ._labels [label_name ])
255269 metric_dict ['max_f1' + suffix ] = metrics_lib .max_f1 (
@@ -261,50 +275,50 @@ def _build_metric_impl(self,
261275 else :
262276 raise ValueError ('Wrong class number' )
263277 elif metric .WhichOneof ('metric' ) == 'recall_at_topk' :
264- assert loss_type in binary_loss_set
278+ assert loss_type & binary_loss_set
265279 assert num_class > 1
266280 label = tf .to_int64 (self ._labels [label_name ])
267281 metric_dict ['recall_at_topk' + suffix ] = metrics_tf .recall_at_k (
268282 label , self ._prediction_dict ['logits' + suffix ],
269283 metric .recall_at_topk .topk )
270284 elif metric .WhichOneof ('metric' ) == 'mean_absolute_error' :
271285 label = tf .to_float (self ._labels [label_name ])
272- if loss_type in [ LossType .L2_LOSS , LossType .SIGMOID_L2_LOSS ] :
286+ if loss_type & { LossType .L2_LOSS , LossType .SIGMOID_L2_LOSS } :
273287 metric_dict ['mean_absolute_error' +
274288 suffix ] = metrics_tf .mean_absolute_error (
275289 label , self ._prediction_dict ['y' + suffix ])
276- elif loss_type == LossType .CLASSIFICATION and num_class == 1 :
290+ elif loss_type & { LossType .CLASSIFICATION } and num_class == 1 :
277291 metric_dict ['mean_absolute_error' +
278292 suffix ] = metrics_tf .mean_absolute_error (
279293 label , self ._prediction_dict ['probs' + suffix ])
280294 else :
281295 assert False , 'mean_absolute_error is not supported for this model'
282296 elif metric .WhichOneof ('metric' ) == 'mean_squared_error' :
283297 label = tf .to_float (self ._labels [label_name ])
284- if loss_type in [ LossType .L2_LOSS , LossType .SIGMOID_L2_LOSS ] :
298+ if loss_type & { LossType .L2_LOSS , LossType .SIGMOID_L2_LOSS } :
285299 metric_dict ['mean_squared_error' +
286300 suffix ] = metrics_tf .mean_squared_error (
287301 label , self ._prediction_dict ['y' + suffix ])
288- elif num_class == 1 and loss_type in binary_loss_set :
302+ elif num_class == 1 and loss_type & binary_loss_set :
289303 metric_dict ['mean_squared_error' +
290304 suffix ] = metrics_tf .mean_squared_error (
291305 label , self ._prediction_dict ['probs' + suffix ])
292306 else :
293307 assert False , 'mean_squared_error is not supported for this model'
294308 elif metric .WhichOneof ('metric' ) == 'root_mean_squared_error' :
295309 label = tf .to_float (self ._labels [label_name ])
296- if loss_type in [ LossType .L2_LOSS , LossType .SIGMOID_L2_LOSS ] :
310+ if loss_type & { LossType .L2_LOSS , LossType .SIGMOID_L2_LOSS } :
297311 metric_dict ['root_mean_squared_error' +
298312 suffix ] = metrics_tf .root_mean_squared_error (
299313 label , self ._prediction_dict ['y' + suffix ])
300- elif loss_type == LossType .CLASSIFICATION and num_class == 1 :
314+ elif loss_type & { LossType .CLASSIFICATION } and num_class == 1 :
301315 metric_dict ['root_mean_squared_error' +
302316 suffix ] = metrics_tf .root_mean_squared_error (
303317 label , self ._prediction_dict ['probs' + suffix ])
304318 else :
305319 assert False , 'root_mean_squared_error is not supported for this model'
306320 elif metric .WhichOneof ('metric' ) == 'accuracy' :
307- assert loss_type == LossType .CLASSIFICATION
321+ assert loss_type & { LossType .CLASSIFICATION }
308322 assert num_class > 1
309323 label = tf .to_int64 (self ._labels [label_name ])
310324 metric_dict ['accuracy' + suffix ] = metrics_tf .accuracy (
@@ -313,20 +327,24 @@ def _build_metric_impl(self,
313327
314328 def build_metric_graph (self , eval_config ):
315329 metric_dict = {}
330+ loss_types = {self ._loss_type }
331+ if len (self ._losses ) > 0 :
332+ loss_types = {loss .loss_type for loss in self ._losses }
316333 for metric in eval_config .metrics_set :
317334 metric_dict .update (
318335 self ._build_metric_impl (
319336 metric ,
320- loss_type = self . _loss_type ,
337+ loss_type = loss_types ,
321338 label_name = self ._label_name ,
322339 num_class = self ._num_class ))
323340 return metric_dict
324341
325342 def _get_outputs_impl (self , loss_type , num_class = 1 , suffix = '' ):
326- if loss_type in [
343+ binary_loss_set = {
327344 LossType .CLASSIFICATION , LossType .F1_REWEIGHTED_LOSS ,
328345 LossType .PAIR_WISE_LOSS
329- ]:
346+ }
347+ if loss_type in binary_loss_set :
330348 if num_class == 1 :
331349 return ['probs' + suffix , 'logits' + suffix ]
332350 else :
@@ -340,4 +358,11 @@ def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
340358 raise ValueError ('invalid loss type: %s' % LossType .Name (loss_type ))
341359
342360 def get_outputs (self ):
343- return self ._get_outputs_impl (self ._loss_type , self ._num_class )
361+ if len (self ._losses ) == 0 :
362+ return self ._get_outputs_impl (self ._loss_type , self ._num_class )
363+
364+ all_outputs = []
365+ for loss in self ._losses :
366+ outputs = self ._get_outputs_impl (loss .loss_type , self ._num_class )
367+ all_outputs .extend (outputs )
368+ return list (set (all_outputs ))
0 commit comments