diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py index b0979404f011..da1b8e420edf 100644 --- a/keras/backend/cntk_backend.py +++ b/keras/backend/cntk_backend.py @@ -1156,17 +1156,20 @@ def permute_dimensions(x, pattern): return C.transpose(x, axis) -def resize_images(x, height_factor, width_factor, data_format): - if data_format == 'channels_first': - output = repeat_elements(x, height_factor, axis=2) - output = repeat_elements(output, width_factor, axis=3) - return output - elif data_format == 'channels_last': - output = repeat_elements(x, height_factor, axis=1) - output = repeat_elements(output, width_factor, axis=2) - return output +def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'): + if interpolation == 'nearest': + if data_format == 'channels_first': + output = repeat_elements(x, height_factor, axis=2) + output = repeat_elements(output, width_factor, axis=3) + return output + elif data_format == 'channels_last': + output = repeat_elements(x, height_factor, axis=1) + output = repeat_elements(output, width_factor, axis=2) + return output + else: + raise ValueError('CNTK Backend: Invalid data_format: %s' % data_format) else: - raise ValueError('CNTK Backend: Invalid data_format:', data_format) + raise NotImplementedError('CNTK only supports `nearest` interpolation.') def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): @@ -1181,7 +1184,7 @@ def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): output = repeat_elements(output, width_factor, axis=3) return output else: - raise ValueError('CNTK Backend: Invalid data_format:', data_format) + raise ValueError('CNTK Backend: Invalid data_format: %s' % data_format) def repeat_elements(x, rep, axis): diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index 21d15963a36d..12fe13525af8 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1962,7 +1962,11 @@ def permute_dimensions(x, pattern): return tf.transpose(x, perm=pattern) -def resize_images(x, height_factor, width_factor, data_format): +def resize_images(x, + height_factor, + width_factor, + data_format, + interpolation='nearest'): """Resizes the images contained in a 4D tensor. # Arguments @@ -1970,6 +1974,7 @@ def resize_images(x, height_factor, width_factor, data_format): height_factor: Positive integer. width_factor: Positive integer. data_format: string, `"channels_last"` or `"channels_first"`. + interpolation: A string, one of `nearest` or `bilinear`. # Returns A tensor. @@ -1978,25 +1983,39 @@ def resize_images(x, height_factor, width_factor, data_format): ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`. """ if data_format == 'channels_first': - original_shape = int_shape(x) - new_shape = tf.shape(x)[2:] - new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32')) + rows, cols = 2, 3 + else: + rows, cols = 1, 2 + + original_shape = int_shape(x) + new_shape = tf.shape(x)[rows:cols + 1] + new_shape *= tf.constant(np.array([height_factor, width_factor], dtype='int32')) + + if data_format == 'channels_first': x = permute_dimensions(x, [0, 2, 3, 1]) + if interpolation == 'nearest': x = tf.image.resize_nearest_neighbor(x, new_shape) + elif interpolation == 'bilinear': + x = tf.image.resize_bilinear(x, new_shape) + else: + raise ValueError('interpolation should be one ' + 'of "nearest" or "bilinear".') + if data_format == 'channels_first': x = permute_dimensions(x, [0, 3, 1, 2]) - x.set_shape((None, None, original_shape[2] * height_factor if original_shape[2] is not None else None, - original_shape[3] * width_factor if original_shape[3] is not None else None)) - return x - elif data_format == 'channels_last': - original_shape = int_shape(x) - new_shape = tf.shape(x)[1:3] - new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32')) - x = tf.image.resize_nearest_neighbor(x, new_shape) - x.set_shape((None, original_shape[1] * height_factor if original_shape[1] is not None else None, - original_shape[2] * width_factor if original_shape[2] is not None else None, None)) - return x + + if original_shape[rows] is None: + new_height = None else: - raise ValueError('Unknown data_format: ' + str(data_format)) + new_height = original_shape[rows] * height_factor + + if original_shape[cols] is None: + new_width = None + else: + new_width = original_shape[cols] * width_factor + + output_shape = (None, new_height, new_width, None) + x.set_shape(transpose_shape(output_shape, data_format, spatial_axes=(1, 2))) + return x def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 1b21694696b5..cdb4a681c857 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -965,7 +965,11 @@ def repeat_elements(x, rep, axis): return y -def resize_images(x, height_factor, width_factor, data_format): +def resize_images(x, + height_factor, + width_factor, + data_format, + interpolation='nearest'): """Resize the images contained in a 4D tensor of shape - [batch, channels, height, width] (for 'channels_first' data_format) - [batch, height, width, channels] (for 'channels_last' data_format) @@ -973,16 +977,40 @@ def resize_images(x, height_factor, width_factor, data_format): positive integers. """ if data_format == 'channels_first': - output = repeat_elements(x, height_factor, axis=2) - output = repeat_elements(output, width_factor, axis=3) - return output + axis_1 = 2 + axis_2 = 3 elif data_format == 'channels_last': - output = repeat_elements(x, height_factor, axis=1) - output = repeat_elements(output, width_factor, axis=2) - return output + axis_1 = 1 + axis_2 = 2 else: raise ValueError('Invalid data_format:', data_format) + if interpolation == 'nearest': + output = repeat_elements(x, height_factor, axis=axis_1) + output = repeat_elements(output, width_factor, axis=axis_2) + elif interpolation == 'bilinear': + if not (height_factor == width_factor == 2): + raise NotImplementedError( + 'Bilinear upscaling with factors other than (2, 2)' + 'is not available when using the Theano backend.') + if data_format == 'channels_last': + output = permute_dimensions(x, [0, 3, 1, 2]) + else: + output = x + output = T.nnet.abstract_conv.bilinear_upsampling(output, + ratio=height_factor) + if data_format == 'channels_last': + output = permute_dimensions(output, [0, 2, 3, 1]) + if hasattr(x, '_keras_shape'): + output._keras_shape = list(x._keras_shape) + output._keras_shape[axis_1] *= height_factor + output._keras_shape[axis_2] *= width_factor + output._keras_shape = tuple(output._keras_shape) + else: + raise ValueError('interpolation should be one of "nearest" or "bilinear".') + + return output + def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): """Resize the volume contained in a 5D tensor of shape diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py index 2c41bd8a04e5..b89b7618f0e3 100644 --- a/keras/layers/convolutional.py +++ b/keras/layers/convolutional.py @@ -1916,6 +1916,9 @@ class UpSampling2D(Layer): It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". + interpolation: A string, one of `nearest` or `bilinear`. + Note that CNTK does not support yet the `bilinear` upscaling + and that with Theano, only `size=(2, 2)` is possible. # Input shape 4D tensor with shape: @@ -1933,11 +1936,15 @@ class UpSampling2D(Layer): """ @interfaces.legacy_upsampling2d_support - def __init__(self, size=(2, 2), data_format=None, **kwargs): + def __init__(self, size=(2, 2), data_format=None, interpolation='nearest', **kwargs): super(UpSampling2D, self).__init__(**kwargs) self.data_format = K.normalize_data_format(data_format) self.size = conv_utils.normalize_tuple(size, 2, 'size') self.input_spec = InputSpec(ndim=4) + if interpolation not in ['nearest', 'bilinear']: + raise ValueError('interpolation should be one ' + 'of "nearest" or "bilinear".') + self.interpolation = interpolation def compute_output_shape(self, input_shape): if self.data_format == 'channels_first': @@ -1957,7 +1964,7 @@ def compute_output_shape(self, input_shape): def call(self, inputs): return K.resize_images(inputs, self.size[0], self.size[1], - self.data_format) + self.data_format, self.interpolation) def get_config(self): config = {'size': self.size, diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py index 726207110bc0..6a8395ff8ee9 100644 --- a/tests/keras/backend/backend_test.py +++ b/tests/keras/backend/backend_test.py @@ -1323,6 +1323,23 @@ def test_resize_images(self): K.resize_images(K.variable(xval), 2, 2, data_format='channels_middle') + @staticmethod + def _helper_bilinear(data_format, height_factor, width_factor): + x_shape = (2, 3, 4, 5) + check_single_tensor_operation('resize_images', x_shape, + [KTF, KTH], + height_factor=height_factor, + width_factor=width_factor, + data_format=data_format, + interpolation='bilinear') + + @pytest.mark.skipif(K.backend() == 'cntk', reason='Not supported.') + @pytest.mark.parametrize('data_format', ['channels_first', 'channels_last']) + def test_resize_images_bilinear(self, data_format): + self._helper_bilinear(data_format, 2, 2) + with pytest.raises(NotImplementedError): + self._helper_bilinear(data_format, 4, 4) + def test_resize_volumes(self): for data_format in ['channels_first', 'channels_last']: shape = (5, 5, 5) diff --git a/tests/keras/layers/convolutional_test.py b/tests/keras/layers/convolutional_test.py index 314823f95f43..df3862a8e0a3 100644 --- a/tests/keras/layers/convolutional_test.py +++ b/tests/keras/layers/convolutional_test.py @@ -862,6 +862,46 @@ def test_upsampling_2d(): assert_allclose(np_output, expected_out) +@pytest.mark.skipif((K.backend() == 'cntk'), + reason='cntk does not support it yet') +@pytest.mark.parametrize('data_format', + ['channels_first', 'channels_last']) +def test_upsampling_2d_bilinear(data_format): + num_samples = 2 + stack_size = 2 + input_num_row = 11 + input_num_col = 12 + + if data_format == 'channels_first': + inputs = np.random.rand(num_samples, stack_size, input_num_row, + input_num_col) + else: # tf + inputs = np.random.rand(num_samples, input_num_row, input_num_col, + stack_size) + + # basic test + layer_test(convolutional.UpSampling2D, + kwargs={'size': (2, 2), + 'data_format': data_format, + 'interpolation': 'bilinear'}, + input_shape=inputs.shape) + + for length_row in [2]: + for length_col in [2, 3]: + layer = convolutional.UpSampling2D( + size=(length_row, length_col), + data_format=data_format) + layer.build(inputs.shape) + outputs = layer(K.variable(inputs)) + np_output = K.eval(outputs) + if data_format == 'channels_first': + assert np_output.shape[2] == length_row * input_num_row + assert np_output.shape[3] == length_col * input_num_col + else: # tf + assert np_output.shape[1] == length_row * input_num_row + assert np_output.shape[2] == length_col * input_num_col + + @pytest.mark.skipif((K.backend() == 'cntk'), reason="cntk does not support it yet") def test_upsampling_3d():