44
55import tensorflow as tf
66
7+ from easy_rec .python .loss .focal_loss import sigmoid_focal_loss_with_logits
8+ from easy_rec .python .loss .jrc_loss import jrc_loss
9+ from easy_rec .python .loss .pairwise_loss import pairwise_focal_loss
10+ from easy_rec .python .loss .pairwise_loss import pairwise_logistic_loss
711from easy_rec .python .loss .pairwise_loss import pairwise_loss
812from easy_rec .python .protos .loss_pb2 import LossType
913
@@ -20,6 +24,7 @@ def build(loss_type,
2024 num_class = 1 ,
2125 loss_param = None ,
2226 ** kwargs ):
27+ loss_name = kwargs .pop ('loss_name' ) if 'loss_name' in kwargs else 'unknown'
2328 if loss_type == LossType .CLASSIFICATION :
2429 if num_class == 1 :
2530 return tf .losses .sigmoid_cross_entropy (
@@ -35,8 +40,60 @@ def build(loss_type,
3540 logging .info ('%s is used' % LossType .Name (loss_type ))
3641 return tf .losses .mean_squared_error (
3742 labels = label , predictions = pred , weights = loss_weight , ** kwargs )
43+ elif loss_type == LossType .JRC_LOSS :
44+ alpha = 0.5 if loss_param is None else loss_param .alpha
45+ auto_weight = False if loss_param is None else not loss_param .HasField (
46+ 'alpha' )
47+ session = kwargs .get ('session_ids' , None )
48+ return jrc_loss (
49+ label , pred , session , alpha , auto_weight = auto_weight , name = loss_name )
3850 elif loss_type == LossType .PAIR_WISE_LOSS :
39- return pairwise_loss (label , pred )
51+ session = kwargs .get ('session_ids' , None )
52+ margin = 0 if loss_param is None else loss_param .margin
53+ temp = 1.0 if loss_param is None else loss_param .temperature
54+ return pairwise_loss (
55+ label ,
56+ pred ,
57+ session_ids = session ,
58+ margin = margin ,
59+ temperature = temp ,
60+ weights = loss_weight ,
61+ name = loss_name )
62+ elif loss_type == LossType .PAIRWISE_LOGISTIC_LOSS :
63+ session = kwargs .get ('session_ids' , None )
64+ temp = 1.0 if loss_param is None else loss_param .temperature
65+ ohem_ratio = 1.0 if loss_param is None else loss_param .ohem_ratio
66+ hinge_margin = None
67+ if loss_param is not None and loss_param .HasField ('hinge_margin' ):
68+ hinge_margin = loss_param .hinge_margin
69+ return pairwise_logistic_loss (
70+ label ,
71+ pred ,
72+ session_ids = session ,
73+ temperature = temp ,
74+ hinge_margin = hinge_margin ,
75+ ohem_ratio = ohem_ratio ,
76+ weights = loss_weight ,
77+ name = loss_name )
78+ elif loss_type == LossType .PAIRWISE_FOCAL_LOSS :
79+ session = kwargs .get ('session_ids' , None )
80+ if loss_param is None :
81+ return pairwise_focal_loss (
82+ label , pred , session_ids = session , weights = loss_weight , name = loss_name )
83+ hinge_margin = None
84+ if loss_param .HasField ('hinge_margin' ):
85+ hinge_margin = loss_param .hinge_margin
86+ return pairwise_focal_loss (
87+ label ,
88+ pred ,
89+ session_ids = session ,
90+ gamma = loss_param .gamma ,
91+ alpha = loss_param .alpha if loss_param .HasField ('alpha' ) else None ,
92+ hinge_margin = hinge_margin ,
93+ ohem_ratio = loss_param .ohem_ratio ,
94+ temperature = loss_param .temperature ,
95+ weights = loss_weight ,
96+ name = loss_name )
4097 elif loss_type == LossType .F1_REWEIGHTED_LOSS :
4198 f1_beta_square = 1.0 if loss_param is None else loss_param .f1_beta_square
4299 label_smoothing = 0 if loss_param is None else loss_param .label_smoothing
@@ -46,6 +103,23 @@ def build(loss_type,
46103 f1_beta_square ,
47104 weights = loss_weight ,
48105 label_smoothing = label_smoothing )
106+ elif loss_type == LossType .BINARY_FOCAL_LOSS :
107+ if loss_param is None :
108+ return sigmoid_focal_loss_with_logits (
109+ label , pred , sample_weights = loss_weight , name = loss_name )
110+ gamma = loss_param .gamma
111+ alpha = None
112+ if loss_param .HasField ('alpha' ):
113+ alpha = loss_param .alpha
114+ return sigmoid_focal_loss_with_logits (
115+ label ,
116+ pred ,
117+ gamma = gamma ,
118+ alpha = alpha ,
119+ ohem_ratio = loss_param .ohem_ratio ,
120+ sample_weights = loss_weight ,
121+ label_smoothing = loss_param .label_smoothing ,
122+ name = loss_name )
49123 else :
50124 raise ValueError ('unsupported loss type: %s' % LossType .Name (loss_type ))
51125
0 commit comments