@@ -226,7 +226,7 @@ def _multi_label_head(n_classes,
                       metric_class_ids=None):
   """Creates a _Head for multi label classification.

-  The Head uses softmax cross entropy loss.
+  The Head uses sigmoid cross entropy loss.

   Args:
     n_classes: Integer, number of classes, must be >= 2
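For context on the corrected line above: multi label targets are multi-hot vectors, so the head applies an independent sigmoid cross entropy to every class instead of one softmax across all classes, which lets several classes be positive at once. A minimal NumPy sketch of that loss, not the library's implementation; the stable form below matches the one documented for `tf.nn.sigmoid_cross_entropy_with_logits`:

import numpy as np

def sigmoid_cross_entropy(labels, logits):
  # Numerically stable elementwise form:
  # max(x, 0) - x * z + log(1 + exp(-|x|))
  x, z = logits, labels
  return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

logits = np.array([[2.0, -1.0, 0.5]])  # one example, n_classes=3
labels = np.array([[1.0, 0.0, 1.0]])   # multi-hot: classes 0 and 2 both on
per_example_loss = sigmoid_cross_entropy(labels, logits).sum(axis=1)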
@@ -246,7 +246,7 @@ def _multi_label_head(n_classes,
         metrics. Must all be in the range `[0, n_classes)`.

   Returns:
-    An instance of _MultiClassHead.
+    An instance of _MultiLabelHead.

   Raises:
     ValueError: if n_classes is < 2
@@ -295,6 +295,11 @@ def _weighted_loss_combiner(losses):
   return _MultiHead(heads, loss_combiner=_weighted_loss_combiner)


+def no_op_train_fn(loss):
+  del loss
+  return control_flow_ops.no_op()
+
+
 # TODO(zakaria): Make the classes public once we are ready for users to subclass
 # them. See b/34751732
 class _Head(object):
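The new `no_op_train_fn` discards its loss argument and returns a TensorFlow no-op, so a head can be asked for loss, predictions, and metrics without building a training op. A hedged usage sketch; `my_head`, `features`, `labels`, `logits`, and `my_optimizer` are placeholders, not part of this patch:

# Head computes loss, predictions, and metrics, but builds no train op.
model_fn_ops = my_head.create_model_fn_ops(
    features=features,
    mode=model_fn.ModeKeys.TRAIN,
    labels=labels,
    train_op_fn=no_op_train_fn,
    logits=logits)
# The caller drives optimization itself from the head's loss.
train_op = my_optimizer.minimize(model_fn_ops.loss)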
@@ -336,7 +341,8 @@ def create_model_fn_ops(self,
       mode: Estimator's `ModeKeys`.
       labels: Labels `Tensor`, or `dict` of same.
       train_op_fn: Function that takes a scalar loss and returns an op to
-          optimize with the loss.
+          optimize with the loss. Must not be `None` in TRAIN mode. If you want
+          to optimize the loss yourself, you can pass `no_op_train_fn`.
       logits: logits `Tensor`, or `dict` of same, to be used for the head.
       logits_input: `Tensor` from which to build logits.
       scope: Optional scope for `variable_scope`.
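A minimal sketch of the `train_op_fn` contract documented above, assuming TF 1.x-style optimizers; the function name, optimizer choice, and learning rate are illustrative:

import tensorflow as tf

def my_train_op_fn(loss):
  # Receives the head's scalar loss; returns the op to run each training step.
  return tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)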
@@ -520,7 +526,9 @@ def _create_model_fn_ops(features,
     logging_ops.scalar_summary(
         _summary_key(head_name, mkey.LOSS), weighted_average_loss)

-    if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
+    if mode == model_fn.ModeKeys.TRAIN:
+      if train_op_fn is None:
+        raise ValueError("train_op_fn can not be None in TRAIN mode")
       train_op = _train_op(loss, labels, train_op_fn, centered_bias,
                            logits_dimension, loss_fn)
     eval_metric_ops = metrics_fn(
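With this check, forgetting `train_op_fn` in TRAIN mode now fails fast instead of silently returning `ModelFnOps` with no train op. A hedged sketch with placeholder names:

try:
  head.create_model_fn_ops(
      features=features, mode=model_fn.ModeKeys.TRAIN,
      labels=labels, train_op_fn=None, logits=logits)
except ValueError as e:
  print(e)  # "train_op_fn can not be None in TRAIN mode"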
@@ -1206,10 +1214,6 @@ def _metrics(self, eval_loss, predictions, labels, weights):
     return metrics


-def _noop(unused_loss):
-  return control_flow_ops.no_op()
-
-
 class _MultiHead(_Head):
   """_Head to combine multiple _Head objects.
@@ -1266,10 +1270,15 @@ def create_model_fn_ops(self,
       labels: Labels `Tensor`, or `dict` of same.
       train_op_fn: Function that takes a scalar loss and returns an op to
           optimize with the loss.
-      logits: Concatenated logits of (x, 1) shape where x is the sum of
-          `logits_dimension` of all the heads, i.e., same as `logits_dimension`
-          of this class. This function will split the logits tensor and pass
-          logits of proper size to each head.
+      logits: Concatenated logits for all heads, or a dict of head name to
+          logits tensor. If concatenated, it should have shape (batch_size, x)
+          where x is the sum of `logits_dimension` of all the heads, i.e.,
+          the same as `logits_dimension` of this class. `create_model_fn_ops`
+          will split the logits tensor and pass logits of the proper size to
+          each head. This is useful if you want to be agnostic about whether
+          you are creating a single head or a multi head. For convenience,
+          logits can also be a dict where you create the head-specific logits
+          explicitly and don't want to concatenate them yourself.
       logits_input: tensor to build logits from.
       scope: Optional scope for variable_scope. If provided, will be passed to
           all heads. Most users will want to set this to `None`, so each head
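A hedged shape sketch of the two accepted `logits` forms, for a multi head built from two heads with `logits_dimension` 3 and 2; the NumPy stand-ins and head names are illustrative:

import numpy as np

logits_a = np.zeros((8, 3))  # logits for a head named "head1"
logits_b = np.zeros((8, 2))  # logits for a head named "head2"

# Form 1: one concatenated (batch_size, 3 + 2) tensor; the multi head
# splits it column-wise and routes each slice to the matching head.
concatenated = np.concatenate([logits_a, logits_b], axis=1)  # shape (8, 5)

# Form 2: a dict keyed by head_name; no concatenation required.
logits_dict = {"head1": logits_a, "head2": logits_b}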
@@ -1287,28 +1296,36 @@ def create_model_fn_ops(self,
     if logits is None:
       # Use logits_input.
       for head in self._heads:
-        # TODO(ptucker): Do we need to let each head create its own logits?
         all_model_fn_ops.append(
             head.create_model_fn_ops(
                 features=features,
                 mode=mode,
                 labels=labels,
-                train_op_fn=_noop,
+                train_op_fn=no_op_train_fn,
                 logits_input=logits_input,
                 scope=scope))
     else:
-      # Split logits for each head.
-      for head, head_logits in zip(self._heads, self._split_logits(logits)):
+      if isinstance(logits, dict):
+        head_logits_pairs = []
+        for head in self._heads:
+          head_logits_pairs.append((head, logits[head.head_name]))
+      else:
+        # Split logits for each head.
+        head_logits_pairs = zip(self._heads, self._split_logits(logits))
+
+      for head, head_logits in head_logits_pairs:
         all_model_fn_ops.append(
             head.create_model_fn_ops(
                 features=features,
                 mode=mode,
                 labels=labels,
-                train_op_fn=_noop,
+                train_op_fn=no_op_train_fn,
                 logits=head_logits,
                 scope=scope))

     if mode == model_fn.ModeKeys.TRAIN:
+      if train_op_fn is None:
+        raise ValueError("train_op_fn can not be None in TRAIN mode.")
       return self._combine_train(all_model_fn_ops, train_op_fn)
     if mode == model_fn.ModeKeys.INFER:
       return self._combine_infer(all_model_fn_ops)
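For the non-dict branch, the concatenated tensor is cut back into per-head slices by each head's `logits_dimension`. A NumPy sketch of that slicing, not the library's internal `_split_logits`; the dimensions are illustrative:

import numpy as np

logits = np.arange(10.0).reshape(2, 5)  # batch of 2, head dims 3 + 2
dims = [3, 2]                           # logits_dimension of each head
offset, pieces = 0, []
for d in dims:
  pieces.append(logits[:, offset:offset + d])
  offset += d
# pieces[0].shape == (2, 3), pieces[1].shape == (2, 2)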