
Commit 990e24a

Zakaria Haque authored and tensorflower-gardener committed
Handles None train_op_fn properly. Makes it required for TRAIN mode. Provides a util function to create a no-op train op for cases where users want to handle optimization in the model function.
Also allows a dict of logits for multi-head. See attached bug for more details. Change: 148913307
1 parent: ec204fc · commit: 990e24a

File tree

2 files changed: +143, -74 lines

tensorflow/contrib/learn/python/learn/estimators/head.py

Lines changed: 35 additions & 17 deletions
@@ -226,7 +226,7 @@ def _multi_label_head(n_classes,
                       metric_class_ids=None):
   """Creates a _Head for multi label classification.
 
-  The Head uses softmax cross entropy loss.
+  The Head uses sigmoid cross entropy loss.
 
   Args:
     n_classes: Integer, number of classes, must be >= 2
@@ -246,7 +246,7 @@ def _multi_label_head(n_classes,
       metrics. Must all be in the range `[0, n_classes)`.
 
   Returns:
-    An instance of _MultiClassHead.
+    An instance of _MultiLabelHead.
 
   Raises:
     ValueError: if n_classes is < 2
@@ -295,6 +295,11 @@ def _weighted_loss_combiner(losses):
   return _MultiHead(heads, loss_combiner=_weighted_loss_combiner)
 
 
+def no_op_train_fn(loss):
+  del loss
+  return control_flow_ops.no_op()
+
+
 # TODO(zakaria): Make the classes public once we are ready for users to subclass
 # them. See b/34751732
 class _Head(object):
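
The exported no_op_train_fn replaces the private _noop helper deleted further down. A minimal sketch of a model function that handles its own optimization and therefore passes it in, assuming the contrib.learn import path below; build_logits is a hypothetical stand-in for the user's network:

from tensorflow.contrib.learn.python.learn.estimators import head as head_lib

def my_model_fn(features, labels, mode):
  # _multi_label_head is the factory documented in this diff; 3 classes is
  # an arbitrary choice for illustration.
  head = head_lib._multi_label_head(n_classes=3)
  # build_logits is hypothetical: any network mapping features to a
  # (batch_size, 3) logits Tensor.
  logits = build_logits(features)
  # no_op_train_fn makes the head create a no-op train op, for callers that
  # run their own optimizer elsewhere in the model function.
  return head.create_model_fn_ops(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=head_lib.no_op_train_fn,
      logits=logits)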
@@ -336,7 +341,8 @@ def create_model_fn_ops(self,
       mode: Estimator's `ModeKeys`.
       labels: Labels `Tensor`, or `dict` of same.
       train_op_fn: Function that takes a scalar loss and returns an op to
-          optimize with the loss.
+          optimize with the loss. Must not be `None` in TRAIN mode. If you want
+          to optimize loss yourself you can pass `no_op_train_fn`.
       logits: logits `Tensor`, or `dict` of same, to be used for the head.
       logits_input: `Tensor` from which to build logits.
       scope: Optional scope for `variable_scope`.
@@ -520,7 +526,9 @@ def _create_model_fn_ops(features,
       logging_ops.scalar_summary(
           _summary_key(head_name, mkey.LOSS), weighted_average_loss)
 
-    if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
+    if mode == model_fn.ModeKeys.TRAIN:
+      if train_op_fn is None:
+        raise ValueError("train_op_fn can not be None in TRAIN mode")
       train_op = _train_op(loss, labels, train_op_fn, centered_bias,
                            logits_dimension, loss_fn)
     eval_metric_ops = metrics_fn(
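
Previously a None train_op_fn in TRAIN mode silently skipped creating a train op; with this change it fails fast. A hedged sketch of the new contract, reusing head, features, labels, and logits from the sketch above:

from tensorflow.contrib.learn.python.learn.estimators import model_fn

try:
  head.create_model_fn_ops(
      features=features,
      mode=model_fn.ModeKeys.TRAIN,
      labels=labels,
      train_op_fn=None,  # now rejected in TRAIN mode
      logits=logits)
except ValueError as err:
  print(err)  # "train_op_fn can not be None in TRAIN mode"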
@@ -1206,10 +1214,6 @@ def _metrics(self, eval_loss, predictions, labels, weights):
     return metrics
 
 
-def _noop(unused_loss):
-  return control_flow_ops.no_op()
-
-
 class _MultiHead(_Head):
   """_Head to combine multiple _Head objects.
 
@@ -1266,10 +1270,15 @@ def create_model_fn_ops(self,
       labels: Labels `Tensor`, or `dict` of same.
       train_op_fn: Function that takes a scalar loss and returns an op to
           optimize with the loss.
-      logits: Concatenated logits of (x, 1) shape where x is the sum of
-          `logits_dimension` of all the heads, i.e., same as `logits_dimension`
-          of this class. This function will split the logits tensor and pass
-          logits of proper size to each head.
+      logits: Concatenated logits for all heads, or a dict of head name to
+          logits tensor. If concatenated, logits should have shape
+          (batch_size, x), where x is the sum of `logits_dimension` over all
+          heads, i.e., the same as `logits_dimension` of this class;
+          create_model_fn_ops will split the logits tensor and pass logits of
+          the proper size to each head. This is useful if you want to be
+          agnostic about whether you create a single head or a multihead.
+          logits can also be a dict, for convenience, when you create
+          head-specific logits explicitly and don't want to concatenate them.
       logits_input: tensor to build logits from.
       scope: Optional scope for variable_scope. If provided, will be passed to
           all heads. Most users will want to set this to `None`, so each head
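
A sketch of the two logits forms this docstring now allows, assuming the _multi_class_head and _multi_head factories from this module; head names, class counts, and the logits1/logits2 and labels1/labels2 tensors are illustrative:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib

# Two heads with logits_dimension 3 and 4; for a multi-head, labels is a
# dict keyed by head_name (labels1/labels2 are hypothetical tensors).
heads = [head_lib._multi_class_head(n_classes=3, head_name="head1"),
         head_lib._multi_class_head(n_classes=4, head_name="head2")]
multi_head = head_lib._multi_head(heads)
labels = {"head1": labels1, "head2": labels2}

# Form 1: concatenated logits of shape (batch_size, 7); the multi-head
# splits them per head.
ops = multi_head.create_model_fn_ops(
    features=features, mode=mode, labels=labels,
    train_op_fn=head_lib.no_op_train_fn,
    logits=tf.concat([logits1, logits2], axis=1))

# Form 2, new in this commit: a dict keyed by head_name, skipping the
# concatenation and the internal split.
ops = multi_head.create_model_fn_ops(
    features=features, mode=mode, labels=labels,
    train_op_fn=head_lib.no_op_train_fn,
    logits={"head1": logits1, "head2": logits2})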
@@ -1287,28 +1296,37 @@ def create_model_fn_ops(self,
     if logits is None:
       # Use logits_input.
       for head in self._heads:
-        # TODO(ptucker): Do we need to let each head create its own logits?
         all_model_fn_ops.append(
             head.create_model_fn_ops(
                 features=features,
                 mode=mode,
                 labels=labels,
-                train_op_fn=_noop,
+                train_op_fn=no_op_train_fn,
                 logits_input=logits_input,
                 scope=scope))
     else:
-      # Split logits for each head.
-      for head, head_logits in zip(self._heads, self._split_logits(logits)):
+      head_logits_pairs = []
+      if isinstance(logits, dict):
+        head_logits_pairs = []
+        for head in self._heads:
+          head_logits_pairs.append((head, logits[head.head_name]))
+      else:
+        # Split logits for each head.
+        head_logits_pairs = zip(self._heads, self._split_logits(logits))
+
+      for head, head_logits in head_logits_pairs:
         all_model_fn_ops.append(
             head.create_model_fn_ops(
                 features=features,
                 mode=mode,
                 labels=labels,
-                train_op_fn=_noop,
+                train_op_fn=no_op_train_fn,
                 logits=head_logits,
                 scope=scope))
 
     if mode == model_fn.ModeKeys.TRAIN:
+      if train_op_fn is None:
+        raise ValueError("train_op_fn can not be None in TRAIN mode.")
       return self._combine_train(all_model_fn_ops, train_op_fn)
     if mode == model_fn.ModeKeys.INFER:
       return self._combine_infer(all_model_fn_ops)
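
Because every sub-head receives no_op_train_fn, the caller's train_op_fn is applied once to the combined loss in _combine_train. The dict branch is then behaviorally equivalent to the split branch whenever the dict values match the split pieces; a minimal sketch of that equivalence with plain TensorFlow ops (_split_logits itself is not shown in this diff, so its exact mechanics are assumed from the docstring):

import tensorflow as tf

logits1 = tf.constant([[0.1, 0.2, 0.3]])       # head1: logits_dimension 3
logits2 = tf.constant([[0.4, 0.5, 0.6, 0.7]])  # head2: logits_dimension 4
concatenated = tf.concat([logits1, logits2], axis=1)  # shape (1, 7)

# What the split branch does, in spirit: carve the concatenated tensor back
# into per-head pieces sized by each head's logits_dimension.
split1, split2 = tf.split(concatenated, [3, 4], axis=1)
# split1 == logits1 and split2 == logits2, so passing the dict
# {"head1": logits1, "head2": logits2} is interchangeable with passing
# `concatenated`.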
