Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Proper handling of output values in Masking layer
  • Loading branch information
tkipf committed Jul 27, 2015
commit cb763aa26bda02929d17608ec6601cfdf334c90a
8 changes: 6 additions & 2 deletions keras/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def get_output_mask(self, train=False):
class Masking(MaskedLayer):
"""Mask an input sequence by using a mask value to identify padding.

This layer copies the input to the output layer,
while creating an output mask in the process.
This layer copies the input to the output layer with identified padding
replaced with 0s and creates an output mask in the process.

At each timestep, if the values all equal `mask_value`,
then the corresponding mask value for the timestep is 0 (skipped),
Expand All @@ -136,6 +136,10 @@ def get_output_mask(self, train=False):
X = self.get_input(train)
return T.any(T.ones_like(X) * (1. - T.eq(X, self.mask_value)), axis=-1)

def get_output(self, train=False):
X = self.get_input(train)
return X * T.shape_padright(T.any((1. - T.eq(X, self.mask_value)), axis=-1))

def get_config(self):
return {"name": self.__class__.__name__,
"mask_value": self.mask_value}
Expand Down
13 changes: 13 additions & 0 deletions tests/auto/keras/layers/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,19 @@ def test_non_zero(self):
# This is the expected output mask, one dimension less
np.array([[1, 1, 1, 0], [1, 1, 1, 1]])))

def test_non_zero_output(self):
"""Test output of masking layer with non-zero mask value"""
layer = core.Masking(5)
func = theano.function([layer.input], layer.get_output())
self.assertTrue(np.all(
# get output for this input, replace padding with 0
func(np.array(
[[[1, 1], [2, 1], [3, 1], [5, 5]],
[[1, 5], [5, 0], [0, 0], [0, 0]]], dtype=np.int32)) ==
# This is the expected output
np.array([[[1, 1], [2, 1], [3, 1], [0, 0]],
[[1, 5], [5, 0], [0, 0], [0, 0]]])))


if __name__ == '__main__':
unittest.main()