Skip to content

Commit 1617ffb

Browse files
Move dynamic_rnn_estimator._construct_rnn_cell() into rnn_common and make it
public. Change: 148875698
1 parent ec86b03 commit 1617ffb

File tree

2 files changed

+65
-58
lines changed

2 files changed

+65
-58
lines changed

tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py

Lines changed: 4 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,6 @@ class RNNKeys(object):
5050
STATE_PREFIX = 'rnn_cell_state'
5151

5252

53-
_CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell,
54-
'lstm': contrib_rnn.LSTMCell,
55-
'gru': contrib_rnn.GRUCell,}
56-
57-
5853
def _get_state_name(i):
5954
"""Constructs the name string for state component `i`."""
6055
return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)
@@ -532,7 +527,8 @@ def _dynamic_rnn_model_fn(features, labels, mode):
532527
dropout = (dropout_keep_probabilities
533528
if mode == model_fn.ModeKeys.TRAIN
534529
else None)
535-
cell = _construct_rnn_cell(cell_type, num_units, dropout)
530+
# This class promises to use the cell type selected by that function.
531+
cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout)
536532
initial_state = dict_to_state_tuple(features, cell)
537533
rnn_activations, final_state = construct_rnn(
538534
initial_state,
@@ -589,58 +585,6 @@ def _dynamic_rnn_model_fn(features, labels, mode):
589585
return _dynamic_rnn_model_fn
590586

591587

592-
def _get_single_cell(cell_type, num_units):
593-
"""Constructs and return an single `RNNCell`.
594-
595-
Args:
596-
cell_type: Either a string identifying the `RNNCell` type, a subclass of
597-
`RNNCell` or an instance of an `RNNCell`.
598-
num_units: The number of units in the `RNNCell`.
599-
Returns:
600-
An initialized `RNNCell`.
601-
Raises:
602-
ValueError: `cell_type` is an invalid `RNNCell` name.
603-
TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
604-
"""
605-
if isinstance(cell_type, contrib_rnn.RNNCell):
606-
return cell_type
607-
if isinstance(cell_type, str):
608-
cell_type = _CELL_TYPES.get(cell_type)
609-
if cell_type is None:
610-
raise ValueError('The supported cell types are {}; got {}'.format(
611-
list(_CELL_TYPES.keys()), cell_type))
612-
if not issubclass(cell_type, contrib_rnn.RNNCell):
613-
raise TypeError(
614-
'cell_type must be a subclass of RNNCell or one of {}.'.format(
615-
list(_CELL_TYPES.keys())))
616-
return cell_type(num_units=num_units)
617-
618-
619-
def _construct_rnn_cell(cell_type, num_units, dropout_keep_probabilities):
620-
"""Constructs cells, applies dropout and assembles a `MultiRNNCell`.
621-
622-
Args:
623-
cell_type: A string identifying the `RNNCell` type, a subclass of
624-
`RNNCell` or an instance of an `RNNCell`.
625-
num_units: A single `int` or a list/tuple of `int`s. The size of the
626-
`RNNCell`s.
627-
dropout_keep_probabilities: a list of dropout probabilities or `None`. If a
628-
list is given, it must have length `len(cell_type) + 1`.
629-
630-
Returns:
631-
An initialized `RNNCell`.
632-
"""
633-
if not isinstance(num_units, (list, tuple)):
634-
num_units = (num_units,)
635-
636-
cells = [_get_single_cell(cell_type, n) for n in num_units]
637-
if dropout_keep_probabilities:
638-
cells = rnn_common.apply_dropout(cells, dropout_keep_probabilities)
639-
if len(cells) == 1:
640-
return cells[0]
641-
return contrib_rnn.MultiRNNCell(cells)
642-
643-
644588
def _get_dropout_and_num_units(cell_type,
645589
num_units,
646590
num_rnn_layers,
@@ -695,6 +639,8 @@ def __init__(self,
695639
all state components or none of them. If none are included, then the default
696640
(zero) state is used as an initial state. See the documentation for
697641
`dict_to_state_tuple` and `state_tuple_to_dict` for further details.
642+
The input function can call rnn_common.construct_rnn_cell() to obtain the
643+
same cell type that this class will select from arguments to __init__.
698644
699645
The `predict()` method of the `Estimator` returns a dictionary with keys
700646
`STATE_PREFIX_i` for `0 <= i < n` where `n` is the number of nested elements

tensorflow/contrib/learn/python/learn/estimators/rnn_common.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,67 @@ class RNNKeys(object):
4040
STATE_PREFIX = 'rnn_cell_state'
4141

4242

43+
_CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell,
44+
'lstm': contrib_rnn.LSTMCell,
45+
'gru': contrib_rnn.GRUCell,}
46+
47+
48+
def _get_single_cell(cell_type, num_units):
49+
"""Constructs and return a single `RNNCell`.
50+
51+
Args:
52+
cell_type: Either a string identifying the `RNNCell` type, a subclass of
53+
`RNNCell` or an instance of an `RNNCell`.
54+
num_units: The number of units in the `RNNCell`.
55+
Returns:
56+
An initialized `RNNCell`.
57+
Raises:
58+
ValueError: `cell_type` is an invalid `RNNCell` name.
59+
TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
60+
"""
61+
if isinstance(cell_type, contrib_rnn.RNNCell):
62+
return cell_type
63+
if isinstance(cell_type, str):
64+
cell_type = _CELL_TYPES.get(cell_type)
65+
if cell_type is None:
66+
raise ValueError('The supported cell types are {}; got {}'.format(
67+
list(_CELL_TYPES.keys()), cell_type))
68+
if not issubclass(cell_type, contrib_rnn.RNNCell):
69+
raise TypeError(
70+
'cell_type must be a subclass of RNNCell or one of {}.'.format(
71+
list(_CELL_TYPES.keys())))
72+
return cell_type(num_units=num_units)
73+
74+
75+
def construct_rnn_cell(num_units, cell_type='basic_rnn',
76+
dropout_keep_probabilities=None):
77+
"""Constructs cells, applies dropout and assembles a `MultiRNNCell`.
78+
79+
The cell type chosen by DynamicRNNEstimator.__init__() is the same as
80+
returned by this function when called with the same arguments.
81+
82+
Args:
83+
num_units: A single `int` or a list/tuple of `int`s. The size of the
84+
`RNNCell`s.
85+
cell_type: A string identifying the `RNNCell` type, a subclass of
86+
`RNNCell` or an instance of an `RNNCell`.
87+
dropout_keep_probabilities: a list of dropout probabilities or `None`. If a
88+
list is given, it must have length `len(cell_type) + 1`.
89+
90+
Returns:
91+
An initialized `RNNCell`.
92+
"""
93+
if not isinstance(num_units, (list, tuple)):
94+
num_units = (num_units,)
95+
96+
cells = [_get_single_cell(cell_type, n) for n in num_units]
97+
if dropout_keep_probabilities:
98+
cells = apply_dropout(cells, dropout_keep_probabilities)
99+
if len(cells) == 1:
100+
return cells[0]
101+
return contrib_rnn.MultiRNNCell(cells)
102+
103+
43104
def apply_dropout(cells, dropout_keep_probabilities, random_seed=None):
44105
"""Applies dropout to the outputs and inputs of `cell`.
45106

0 commit comments

Comments
 (0)