@@ -50,11 +50,6 @@ class RNNKeys(object):
5050 STATE_PREFIX = 'rnn_cell_state'
5151
5252
53- _CELL_TYPES = {'basic_rnn' : contrib_rnn .BasicRNNCell ,
54- 'lstm' : contrib_rnn .LSTMCell ,
55- 'gru' : contrib_rnn .GRUCell ,}
56-
57-
5853def _get_state_name (i ):
5954 """Constructs the name string for state component `i`."""
6055 return '{}_{}' .format (rnn_common .RNNKeys .STATE_PREFIX , i )
@@ -532,7 +527,8 @@ def _dynamic_rnn_model_fn(features, labels, mode):
532527 dropout = (dropout_keep_probabilities
533528 if mode == model_fn .ModeKeys .TRAIN
534529 else None )
535- cell = _construct_rnn_cell (cell_type , num_units , dropout )
530+ # This class promises to use the cell type selected by that function.
531+ cell = rnn_common .construct_rnn_cell (num_units , cell_type , dropout )
536532 initial_state = dict_to_state_tuple (features , cell )
537533 rnn_activations , final_state = construct_rnn (
538534 initial_state ,
@@ -589,58 +585,6 @@ def _dynamic_rnn_model_fn(features, labels, mode):
589585 return _dynamic_rnn_model_fn
590586
591587
592- def _get_single_cell (cell_type , num_units ):
593- """Constructs and return an single `RNNCell`.
594-
595- Args:
596- cell_type: Either a string identifying the `RNNCell` type, a subclass of
597- `RNNCell` or an instance of an `RNNCell`.
598- num_units: The number of units in the `RNNCell`.
599- Returns:
600- An initialized `RNNCell`.
601- Raises:
602- ValueError: `cell_type` is an invalid `RNNCell` name.
603- TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
604- """
605- if isinstance (cell_type , contrib_rnn .RNNCell ):
606- return cell_type
607- if isinstance (cell_type , str ):
608- cell_type = _CELL_TYPES .get (cell_type )
609- if cell_type is None :
610- raise ValueError ('The supported cell types are {}; got {}' .format (
611- list (_CELL_TYPES .keys ()), cell_type ))
612- if not issubclass (cell_type , contrib_rnn .RNNCell ):
613- raise TypeError (
614- 'cell_type must be a subclass of RNNCell or one of {}.' .format (
615- list (_CELL_TYPES .keys ())))
616- return cell_type (num_units = num_units )
617-
618-
619- def _construct_rnn_cell (cell_type , num_units , dropout_keep_probabilities ):
620- """Constructs cells, applies dropout and assembles a `MultiRNNCell`.
621-
622- Args:
623- cell_type: A string identifying the `RNNCell` type, a subclass of
624- `RNNCell` or an instance of an `RNNCell`.
625- num_units: A single `int` or a list/tuple of `int`s. The size of the
626- `RNNCell`s.
627- dropout_keep_probabilities: a list of dropout probabilities or `None`. If a
628- list is given, it must have length `len(cell_type) + 1`.
629-
630- Returns:
631- An initialized `RNNCell`.
632- """
633- if not isinstance (num_units , (list , tuple )):
634- num_units = (num_units ,)
635-
636- cells = [_get_single_cell (cell_type , n ) for n in num_units ]
637- if dropout_keep_probabilities :
638- cells = rnn_common .apply_dropout (cells , dropout_keep_probabilities )
639- if len (cells ) == 1 :
640- return cells [0 ]
641- return contrib_rnn .MultiRNNCell (cells )
642-
643-
644588def _get_dropout_and_num_units (cell_type ,
645589 num_units ,
646590 num_rnn_layers ,
@@ -695,6 +639,8 @@ def __init__(self,
695639 all state components or none of them. If none are included, then the default
696640 (zero) state is used as an initial state. See the documentation for
697641 `dict_to_state_tuple` and `state_tuple_to_dict` for further details.
642+ The input function can call rnn_common.construct_rnn_cell() to obtain the
643+ same cell type that this class will select from arguments to __init__.
698644
699645 The `predict()` method of the `Estimator` returns a dictionary with keys
700646 `STATE_PREFIX_i` for `0 <= i < n` where `n` is the number of nested elements
0 commit comments