From 0855745c1b23011a6be047b5c707a169dd4cff27 Mon Sep 17 00:00:00 2001 From: Andrew Hundt Date: Sat, 3 Feb 2018 19:19:49 -0500 Subject: [PATCH 01/14] bilinear upsampling initial implementation --- keras/backend/cntk_backend.py | 23 ++++++++------- keras/backend/tensorflow_backend.py | 9 ++++-- keras/backend/theano_backend.py | 29 ++++++++++++++----- keras/layers/convolutional.py | 7 +++-- tests/keras/layers/convolutional_test.py | 37 ++++++++++++++++++++++++ 5 files changed, 84 insertions(+), 21 deletions(-) diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py index 84ae05343a64..88b4d91492b2 100644 --- a/keras/backend/cntk_backend.py +++ b/keras/backend/cntk_backend.py @@ -1122,17 +1122,20 @@ def permute_dimensions(x, pattern): return C.transpose(x, axis) -def resize_images(x, height_factor, width_factor, data_format): - if data_format == 'channels_first': - output = repeat_elements(x, height_factor, axis=2) - output = repeat_elements(output, width_factor, axis=3) - return output - elif data_format == 'channels_last': - output = repeat_elements(x, height_factor, axis=1) - output = repeat_elements(output, width_factor, axis=2) - return output +def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'): + if interpolation == 'nearest': + if data_format == 'channels_first': + output = repeat_elements(x, height_factor, axis=2) + output = repeat_elements(output, width_factor, axis=3) + return output + elif data_format == 'channels_last': + output = repeat_elements(x, height_factor, axis=1) + output = repeat_elements(output, width_factor, axis=2) + return output + else: + raise ValueError('CNTK Backend: Invalid data_format:', data_format) else: - raise ValueError('CNTK Backend: Invalid data_format:', data_format) + raise NotImplementedError def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index ef6e6c353956..ac2f88b859a8 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1905,7 +1905,7 @@ def permute_dimensions(x, pattern): return tf.transpose(x, perm=pattern) -def resize_images(x, height_factor, width_factor, data_format): +def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'): """Resizes the images contained in a 4D tensor. # Arguments @@ -1925,7 +1925,12 @@ def resize_images(x, height_factor, width_factor, data_format): new_shape = tf.shape(x)[2:] new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32')) x = permute_dimensions(x, [0, 2, 3, 1]) - x = tf.image.resize_nearest_neighbor(x, new_shape) + if interpolation == 'nearest': + x = tf.image.resize_nearest_neighbor(x, new_shape) + elif interpolation == 'bilinear': + x = tf.image.resize_bilinear(x, new_shape) + else: + raise ValueError('interpolation should be one of "nearest" or "bilinear".') x = permute_dimensions(x, [0, 3, 1, 2]) x.set_shape((None, None, original_shape[2] * height_factor if original_shape[2] is not None else None, original_shape[3] * width_factor if original_shape[3] is not None else None)) diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 7c6a51df23c3..a4ab6f475761 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -884,7 +884,7 @@ def repeat_elements(x, rep, axis): return y -def resize_images(x, height_factor, width_factor, data_format): +def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'): """Resize the images contained in a 4D tensor of shape - [batch, channels, height, width] (for 'channels_first' data_format) - [batch, height, width, channels] (for 'channels_last' data_format) @@ -892,16 +892,31 @@ def resize_images(x, height_factor, width_factor, data_format): positive integers. """ if data_format == 'channels_first': - output = repeat_elements(x, height_factor, axis=2) - output = repeat_elements(output, width_factor, axis=3) - return output + axis_1 = 2 + axis_2 = 3 elif data_format == 'channels_last': - output = repeat_elements(x, height_factor, axis=1) - output = repeat_elements(output, width_factor, axis=2) - return output + axis_1 = 1 + axis_2 = 2 else: raise ValueError('Invalid data_format:', data_format) + if interpolation == 'nearest': + output = repeat_elements(x, height_factor, axis=axis_1) + output = repeat_elements(output, width_factor, axis=axis_2) + elif interpolation == 'bilinear': + output = theano.tensor.nnet.abstract_conv.bilinear_upsampling(x, ratio=(height_factor, width_factor)) + if hasattr(x, '_keras_shape'): + output._keras_shape = list(x._keras_shape) + repeat_dim_1 = x._keras_shape[axis_1] + repeat_dim_2 = x._keras_shape[axis_2] + output._keras_shape[axis_1] = repeat_dim_1 * rep + output._keras_shape[axis_2] = repeat_dim_2 * rep + output._keras_shape = tuple(output._keras_shape) + else: + raise ValueError('interpolation should be one of "nearest" or "bilinear".') + + return output + def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): """Resize the volume contained in a 5D tensor of shape diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py index 1cc58859057d..121099cf73f5 100644 --- a/keras/layers/convolutional.py +++ b/keras/layers/convolutional.py @@ -1582,11 +1582,14 @@ class UpSampling2D(Layer): """ @interfaces.legacy_upsampling2d_support - def __init__(self, size=(2, 2), data_format=None, **kwargs): + def __init__(self, size=(2, 2), data_format=None, interpolation='nearest', **kwargs): super(UpSampling2D, self).__init__(**kwargs) self.data_format = conv_utils.normalize_data_format(data_format) self.size = conv_utils.normalize_tuple(size, 2, 'size') self.input_spec = InputSpec(ndim=4) + if interpolation not in ['nearest', 'bilinear']: + raise ValueError('interpolation should be one of "nearest" or "bilinear".') + self.interpolation = interpolation def compute_output_shape(self, input_shape): if self.data_format == 'channels_first': @@ -1606,7 +1609,7 @@ def compute_output_shape(self, input_shape): def call(self, inputs): return K.resize_images(inputs, self.size[0], self.size[1], - self.data_format) + self.data_format, self.interpolation) def get_config(self): config = {'size': self.size, diff --git a/tests/keras/layers/convolutional_test.py b/tests/keras/layers/convolutional_test.py index 4bb3fe58430a..1ee44f54f39f 100644 --- a/tests/keras/layers/convolutional_test.py +++ b/tests/keras/layers/convolutional_test.py @@ -734,6 +734,43 @@ def test_upsampling_2d(): assert_allclose(np_output, expected_out) +@pytest.mark.skipif((K.backend() == 'cntk'), + reason="cntk does not support it yet") +def test_upsampling_2d_bilinear(): + num_samples = 2 + stack_size = 2 + input_num_row = 11 + input_num_col = 12 + + for data_format in ['channels_first', 'channels_last']: + if data_format == 'channels_first': + inputs = np.random.rand(num_samples, stack_size, input_num_row, + input_num_col) + else: # tf + inputs = np.random.rand(num_samples, input_num_row, input_num_col, + stack_size) + + # basic test + layer_test(convolutional.UpSampling2D, + kwargs={'size': (2, 2), 'data_format': data_format, 'interpolation': 'bilinear'}, + input_shape=inputs.shape) + + for length_row in [2]: + for length_col in [2, 3]: + layer = convolutional.UpSampling2D( + size=(length_row, length_col), + data_format=data_format) + layer.build(inputs.shape) + outputs = layer(K.variable(inputs)) + np_output = K.eval(outputs) + if data_format == 'channels_first': + assert np_output.shape[2] == length_row * input_num_row + assert np_output.shape[3] == length_col * input_num_col + else: # tf + assert np_output.shape[1] == length_row * input_num_row + assert np_output.shape[2] == length_col * input_num_col + + @pytest.mark.skipif((K.backend() == 'cntk'), reason="cntk does not support it yet") def test_upsampling_3d(): From 4b03e6ae42ebc25a3da14e82e19418d7a2443f8b Mon Sep 17 00:00:00 2001 From: Andrew Hundt Date: Sat, 3 Feb 2018 22:50:45 -0500 Subject: [PATCH 02/14] bilinear interpolation docstrings added --- keras/backend/tensorflow_backend.py | 1 + keras/layers/convolutional.py | 1 + 2 files changed, 2 insertions(+) diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index ac2f88b859a8..463c832c851a 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1913,6 +1913,7 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne height_factor: Positive integer. width_factor: Positive integer. data_format: string, `"channels_last"` or `"channels_first"`. + interpolation: A string, one of `nearest` or `bilinear`. # Returns A tensor. diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py index 121099cf73f5..919b66dd3242 100644 --- a/keras/layers/convolutional.py +++ b/keras/layers/convolutional.py @@ -1565,6 +1565,7 @@ class UpSampling2D(Layer): It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". + interpolation: A string, one of `nearest` or `bilinear`. # Input shape 4D tensor with shape: From b7c10bd628291300b52e913ed7e477784eeca11a Mon Sep 17 00:00:00 2001 From: Andrew Hundt Date: Mon, 5 Feb 2018 02:02:39 -0500 Subject: [PATCH 03/14] thano_backend.py fix bilinear interpolation ratio --- keras/backend/theano_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index a4ab6f475761..368a664ebb11 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -904,7 +904,8 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne output = repeat_elements(x, height_factor, axis=axis_1) output = repeat_elements(output, width_factor, axis=axis_2) elif interpolation == 'bilinear': - output = theano.tensor.nnet.abstract_conv.bilinear_upsampling(x, ratio=(height_factor, width_factor)) + ratio = height_factor/width_factor + output = theano.tensor.nnet.abstract_conv.bilinear_upsampling(x, ratio=ratio) if hasattr(x, '_keras_shape'): output._keras_shape = list(x._keras_shape) repeat_dim_1 = x._keras_shape[axis_1] From 104f209a3ab41503182502b54d2fc488e0e909b0 Mon Sep 17 00:00:00 2001 From: Andrew Hundt Date: Mon, 5 Feb 2018 21:32:02 -0500 Subject: [PATCH 04/14] theano_backend.py pep8 & padding fix --- keras/backend/theano_backend.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 368a664ebb11..4427cc3cec80 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -904,8 +904,10 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne output = repeat_elements(x, height_factor, axis=axis_1) output = repeat_elements(output, width_factor, axis=axis_2) elif interpolation == 'bilinear': - ratio = height_factor/width_factor - output = theano.tensor.nnet.abstract_conv.bilinear_upsampling(x, ratio=ratio) + ratio = height_factor / width_factor + th_padding = _preprocess_padding('same') + output = theano.tensor.nnet.abstract_conv.bilinear_upsampling( + x, ratio=ratio, border_mode=th_padding) if hasattr(x, '_keras_shape'): output._keras_shape = list(x._keras_shape) repeat_dim_1 = x._keras_shape[axis_1] From 73f031181a0a3c0693768c2e4a5dd48c76d08116 Mon Sep 17 00:00:00 2001 From: Andrew Hundt Date: Wed, 14 Feb 2018 15:21:37 -0500 Subject: [PATCH 05/14] theano_backend.py remove border_mode as per review --- keras/backend/theano_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 4427cc3cec80..1b59020a26e1 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -907,7 +907,7 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne ratio = height_factor / width_factor th_padding = _preprocess_padding('same') output = theano.tensor.nnet.abstract_conv.bilinear_upsampling( - x, ratio=ratio, border_mode=th_padding) + x, ratio=ratio) if hasattr(x, '_keras_shape'): output._keras_shape = list(x._keras_shape) repeat_dim_1 = x._keras_shape[axis_1] From bd1b3b49c1cdee95e8437073c9c103134eedbe59 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Sat, 25 Aug 2018 16:21:38 +0200 Subject: [PATCH 06/14] Added the mode "bilinear" in the upscaling2D layer. Theano only supports (2, 2) though. --- keras/backend/tensorflow_backend.py | 14 ++++++++---- keras/backend/theano_backend.py | 29 ++++++++++++++++-------- keras/layers/convolutional.py | 3 ++- tests/keras/backend/backend_test.py | 17 ++++++++++++++ tests/keras/layers/convolutional_test.py | 4 +++- 5 files changed, 52 insertions(+), 15 deletions(-) diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index 0e400cf41816..20061d11cfc8 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1978,8 +1978,8 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne # Raises ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`. """ + original_shape = int_shape(x) if data_format == 'channels_first': - original_shape = int_shape(x) new_shape = tf.shape(x)[2:] new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32')) x = permute_dimensions(x, [0, 2, 3, 1]) @@ -1988,16 +1988,22 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne elif interpolation == 'bilinear': x = tf.image.resize_bilinear(x, new_shape) else: - raise ValueError('interpolation should be one of "nearest" or "bilinear".') + raise ValueError('interpolation should be one ' + 'of "nearest" or "bilinear".') x = permute_dimensions(x, [0, 3, 1, 2]) x.set_shape((None, None, original_shape[2] * height_factor if original_shape[2] is not None else None, original_shape[3] * width_factor if original_shape[3] is not None else None)) return x elif data_format == 'channels_last': - original_shape = int_shape(x) new_shape = tf.shape(x)[1:3] new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32')) - x = tf.image.resize_nearest_neighbor(x, new_shape) + if interpolation == 'nearest': + x = tf.image.resize_nearest_neighbor(x, new_shape) + elif interpolation == 'bilinear': + x = tf.image.resize_bilinear(x, new_shape) + else: + raise ValueError('interpolation should be one ' + 'of "nearest" or "bilinear".') x.set_shape((None, original_shape[1] * height_factor if original_shape[1] is not None else None, original_shape[2] * width_factor if original_shape[2] is not None else None, None)) return x diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index f145bf1ff707..0e3971c26e05 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -19,6 +19,7 @@ except ImportError: from theano.sandbox.softsign import softsign as T_softsign +import warnings import numpy as np from .common import floatx from .common import epsilon @@ -965,7 +966,11 @@ def repeat_elements(x, rep, axis): return y -def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'): +def resize_images(x, + height_factor, + width_factor, + data_format, + interpolation='nearest'): """Resize the images contained in a 4D tensor of shape - [batch, channels, height, width] (for 'channels_first' data_format) - [batch, height, width, channels] (for 'channels_last' data_format) @@ -985,16 +990,22 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne output = repeat_elements(x, height_factor, axis=axis_1) output = repeat_elements(output, width_factor, axis=axis_2) elif interpolation == 'bilinear': - ratio = height_factor / width_factor - th_padding = _preprocess_padding('same') - output = theano.tensor.nnet.abstract_conv.bilinear_upsampling( - x, ratio=ratio) + if not (height_factor == width_factor == 2): + raise NotImplementedError( + 'Bilinear upscaling with factors other than (2, 2)' + 'is not available when using the Theano backend.') + if data_format == 'channels_last': + output = permute_dimensions(x, [0, 3, 1, 2]) + else: + output = x + output = T.nnet.abstract_conv.bilinear_upsampling(output, + ratio=height_factor) + if data_format == 'channels_last': + output = permute_dimensions(output, [0, 2, 3, 1]) if hasattr(x, '_keras_shape'): output._keras_shape = list(x._keras_shape) - repeat_dim_1 = x._keras_shape[axis_1] - repeat_dim_2 = x._keras_shape[axis_2] - output._keras_shape[axis_1] = repeat_dim_1 * rep - output._keras_shape[axis_2] = repeat_dim_2 * rep + output._keras_shape[axis_1] *= height_factor + output._keras_shape[axis_2] *= width_factor output._keras_shape = tuple(output._keras_shape) else: raise ValueError('interpolation should be one of "nearest" or "bilinear".') diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py index a1299efec0f8..9a997f912adf 100644 --- a/keras/layers/convolutional.py +++ b/keras/layers/convolutional.py @@ -1940,7 +1940,8 @@ def __init__(self, size=(2, 2), data_format=None, interpolation='nearest', **kwa self.size = conv_utils.normalize_tuple(size, 2, 'size') self.input_spec = InputSpec(ndim=4) if interpolation not in ['nearest', 'bilinear']: - raise ValueError('interpolation should be one of "nearest" or "bilinear".') + raise ValueError('interpolation should be one ' + 'of "nearest" or "bilinear".') self.interpolation = interpolation def compute_output_shape(self, input_shape): diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py index 726207110bc0..6b454cbb45eb 100644 --- a/tests/keras/backend/backend_test.py +++ b/tests/keras/backend/backend_test.py @@ -1323,6 +1323,23 @@ def test_resize_images(self): K.resize_images(K.variable(xval), 2, 2, data_format='channels_middle') + @staticmethod + def _helper_bilinear(height_factor, width_factor): + x_shape = (2, 3, 4, 5) + for data_format in ['channels_first', 'channels_last']: + check_single_tensor_operation('resize_images', x_shape, + [KTF, KTH], + height_factor=height_factor, + width_factor=width_factor, + data_format=data_format, + interpolation='bilinear') + + @pytest.mark.skipif(K.backend() == 'cntk', reason='Not supported.') + def test_resize_images_bilinear(self): + self._helper_bilinear(2, 2) + with pytest.raises(NotImplementedError): + self._helper_bilinear(4, 4) + def test_resize_volumes(self): for data_format in ['channels_first', 'channels_last']: shape = (5, 5, 5) diff --git a/tests/keras/layers/convolutional_test.py b/tests/keras/layers/convolutional_test.py index d61ba490555a..88bd108e2b42 100644 --- a/tests/keras/layers/convolutional_test.py +++ b/tests/keras/layers/convolutional_test.py @@ -880,7 +880,9 @@ def test_upsampling_2d_bilinear(): # basic test layer_test(convolutional.UpSampling2D, - kwargs={'size': (2, 2), 'data_format': data_format, 'interpolation': 'bilinear'}, + kwargs={'size': (2, 2), + 'data_format': data_format, + 'interpolation': 'bilinear'}, input_shape=inputs.shape) for length_row in [2]: From 2833ee4b3b804a485f3f2e2639634b293676d9d6 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Sat, 25 Aug 2018 16:30:01 +0200 Subject: [PATCH 07/14] Added some details about what was possible with bilinear interpolation. --- keras/layers/convolutional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py index 9a997f912adf..b89b7618f0e3 100644 --- a/keras/layers/convolutional.py +++ b/keras/layers/convolutional.py @@ -1917,6 +1917,8 @@ class UpSampling2D(Layer): Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". interpolation: A string, one of `nearest` or `bilinear`. + Note that CNTK does not support yet the `bilinear` upscaling + and that with Theano, only `size=(2, 2)` is possible. # Input shape 4D tensor with shape: From 8b6baf3683e325a6eba12ef7dee1099aca8bd148 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Sat, 25 Aug 2018 17:15:04 +0200 Subject: [PATCH 08/14] Made the lines shorted in the resize implementation of tensorflow. --- keras/backend/tensorflow_backend.py | 61 ++++++++++++++++------------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index 20061d11cfc8..12fe13525af8 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1962,7 +1962,11 @@ def permute_dimensions(x, pattern): return tf.transpose(x, perm=pattern) -def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'): +def resize_images(x, + height_factor, + width_factor, + data_format, + interpolation='nearest'): """Resizes the images contained in a 4D tensor. # Arguments @@ -1978,37 +1982,40 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne # Raises ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`. """ + if data_format == 'channels_first': + rows, cols = 2, 3 + else: + rows, cols = 1, 2 + original_shape = int_shape(x) + new_shape = tf.shape(x)[rows:cols + 1] + new_shape *= tf.constant(np.array([height_factor, width_factor], dtype='int32')) + if data_format == 'channels_first': - new_shape = tf.shape(x)[2:] - new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32')) x = permute_dimensions(x, [0, 2, 3, 1]) - if interpolation == 'nearest': - x = tf.image.resize_nearest_neighbor(x, new_shape) - elif interpolation == 'bilinear': - x = tf.image.resize_bilinear(x, new_shape) - else: - raise ValueError('interpolation should be one ' - 'of "nearest" or "bilinear".') + if interpolation == 'nearest': + x = tf.image.resize_nearest_neighbor(x, new_shape) + elif interpolation == 'bilinear': + x = tf.image.resize_bilinear(x, new_shape) + else: + raise ValueError('interpolation should be one ' + 'of "nearest" or "bilinear".') + if data_format == 'channels_first': x = permute_dimensions(x, [0, 3, 1, 2]) - x.set_shape((None, None, original_shape[2] * height_factor if original_shape[2] is not None else None, - original_shape[3] * width_factor if original_shape[3] is not None else None)) - return x - elif data_format == 'channels_last': - new_shape = tf.shape(x)[1:3] - new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32')) - if interpolation == 'nearest': - x = tf.image.resize_nearest_neighbor(x, new_shape) - elif interpolation == 'bilinear': - x = tf.image.resize_bilinear(x, new_shape) - else: - raise ValueError('interpolation should be one ' - 'of "nearest" or "bilinear".') - x.set_shape((None, original_shape[1] * height_factor if original_shape[1] is not None else None, - original_shape[2] * width_factor if original_shape[2] is not None else None, None)) - return x + + if original_shape[rows] is None: + new_height = None else: - raise ValueError('Unknown data_format: ' + str(data_format)) + new_height = original_shape[rows] * height_factor + + if original_shape[cols] is None: + new_width = None + else: + new_width = original_shape[cols] * width_factor + + output_shape = (None, new_height, new_width, None) + x.set_shape(transpose_shape(output_shape, data_format, spatial_axes=(1, 2))) + return x def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): From c4b781ec01411b6757c55097d911af8e50d6ad92 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Sun, 26 Aug 2018 11:32:30 +0200 Subject: [PATCH 09/14] Added a more meaningful error message for bilinear and CNTK. --- keras/backend/cntk_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py index a0c1b7d227bc..3b7eeaea4b9b 100644 --- a/keras/backend/cntk_backend.py +++ b/keras/backend/cntk_backend.py @@ -1169,7 +1169,7 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne else: raise ValueError('CNTK Backend: Invalid data_format:', data_format) else: - raise NotImplementedError + raise NotImplementedError('CNTK only supports `nearest` interpolation.') def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): From 9d155e2c71719c3ba04cfd27743f5b17be454ac6 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Sun, 26 Aug 2018 20:11:18 +0200 Subject: [PATCH 10/14] Used ' instead of " for consistency. --- tests/keras/layers/convolutional_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/keras/layers/convolutional_test.py b/tests/keras/layers/convolutional_test.py index 88bd108e2b42..1d5b0febae39 100644 --- a/tests/keras/layers/convolutional_test.py +++ b/tests/keras/layers/convolutional_test.py @@ -863,7 +863,7 @@ def test_upsampling_2d(): @pytest.mark.skipif((K.backend() == 'cntk'), - reason="cntk does not support it yet") + reason='cntk does not support it yet') def test_upsampling_2d_bilinear(): num_samples = 2 stack_size = 2 From 7770684566795b72d295c31ea898b7db5ce5145f Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Mon, 27 Aug 2018 15:17:39 +0200 Subject: [PATCH 11/14] Parametrized the bilinear test. --- tests/keras/backend/backend_test.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py index 6b454cbb45eb..6a8395ff8ee9 100644 --- a/tests/keras/backend/backend_test.py +++ b/tests/keras/backend/backend_test.py @@ -1324,21 +1324,21 @@ def test_resize_images(self): data_format='channels_middle') @staticmethod - def _helper_bilinear(height_factor, width_factor): + def _helper_bilinear(data_format, height_factor, width_factor): x_shape = (2, 3, 4, 5) - for data_format in ['channels_first', 'channels_last']: - check_single_tensor_operation('resize_images', x_shape, - [KTF, KTH], - height_factor=height_factor, - width_factor=width_factor, - data_format=data_format, - interpolation='bilinear') + check_single_tensor_operation('resize_images', x_shape, + [KTF, KTH], + height_factor=height_factor, + width_factor=width_factor, + data_format=data_format, + interpolation='bilinear') @pytest.mark.skipif(K.backend() == 'cntk', reason='Not supported.') - def test_resize_images_bilinear(self): - self._helper_bilinear(2, 2) + @pytest.mark.parametrize('data_format', ['channels_first', 'channels_last']) + def test_resize_images_bilinear(self, data_format): + self._helper_bilinear(data_format, 2, 2) with pytest.raises(NotImplementedError): - self._helper_bilinear(4, 4) + self._helper_bilinear(data_format, 4, 4) def test_resize_volumes(self): for data_format in ['channels_first', 'channels_last']: From 252890634bad111cc351582d93ad7546e854389a Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Mon, 27 Aug 2018 18:07:32 +0200 Subject: [PATCH 12/14] Removed unused import. --- keras/backend/theano_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 0e3971c26e05..cdb4a681c857 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -19,7 +19,6 @@ except ImportError: from theano.sandbox.softsign import softsign as T_softsign -import warnings import numpy as np from .common import floatx from .common import epsilon From 70654da85de07791c299153d42016b84ecbb750c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Chollet?= Date: Tue, 28 Aug 2018 10:18:44 -0700 Subject: [PATCH 13/14] Fix error messages. --- keras/backend/cntk_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py index 3b7eeaea4b9b..da1b8e420edf 100644 --- a/keras/backend/cntk_backend.py +++ b/keras/backend/cntk_backend.py @@ -1167,7 +1167,7 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne output = repeat_elements(output, width_factor, axis=2) return output else: - raise ValueError('CNTK Backend: Invalid data_format:', data_format) + raise ValueError('CNTK Backend: Invalid data_format: %s' % data_format) else: raise NotImplementedError('CNTK only supports `nearest` interpolation.') @@ -1184,7 +1184,7 @@ def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): output = repeat_elements(output, width_factor, axis=3) return output else: - raise ValueError('CNTK Backend: Invalid data_format:', data_format) + raise ValueError('CNTK Backend: Invalid data_format: %s' % data_format) def repeat_elements(x, rep, axis): From dac12e983c66f75cd24b8845054a108359a8df72 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Wed, 29 Aug 2018 18:38:55 +0200 Subject: [PATCH 14/14] Parametrized tests. --- tests/keras/layers/convolutional_test.py | 57 ++++++++++++------------ 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/tests/keras/layers/convolutional_test.py b/tests/keras/layers/convolutional_test.py index 1d5b0febae39..df3862a8e0a3 100644 --- a/tests/keras/layers/convolutional_test.py +++ b/tests/keras/layers/convolutional_test.py @@ -864,41 +864,42 @@ def test_upsampling_2d(): @pytest.mark.skipif((K.backend() == 'cntk'), reason='cntk does not support it yet') -def test_upsampling_2d_bilinear(): +@pytest.mark.parametrize('data_format', + ['channels_first', 'channels_last']) +def test_upsampling_2d_bilinear(data_format): num_samples = 2 stack_size = 2 input_num_row = 11 input_num_col = 12 - for data_format in ['channels_first', 'channels_last']: - if data_format == 'channels_first': - inputs = np.random.rand(num_samples, stack_size, input_num_row, - input_num_col) - else: # tf - inputs = np.random.rand(num_samples, input_num_row, input_num_col, - stack_size) + if data_format == 'channels_first': + inputs = np.random.rand(num_samples, stack_size, input_num_row, + input_num_col) + else: # tf + inputs = np.random.rand(num_samples, input_num_row, input_num_col, + stack_size) - # basic test - layer_test(convolutional.UpSampling2D, - kwargs={'size': (2, 2), - 'data_format': data_format, - 'interpolation': 'bilinear'}, - input_shape=inputs.shape) + # basic test + layer_test(convolutional.UpSampling2D, + kwargs={'size': (2, 2), + 'data_format': data_format, + 'interpolation': 'bilinear'}, + input_shape=inputs.shape) - for length_row in [2]: - for length_col in [2, 3]: - layer = convolutional.UpSampling2D( - size=(length_row, length_col), - data_format=data_format) - layer.build(inputs.shape) - outputs = layer(K.variable(inputs)) - np_output = K.eval(outputs) - if data_format == 'channels_first': - assert np_output.shape[2] == length_row * input_num_row - assert np_output.shape[3] == length_col * input_num_col - else: # tf - assert np_output.shape[1] == length_row * input_num_row - assert np_output.shape[2] == length_col * input_num_col + for length_row in [2]: + for length_col in [2, 3]: + layer = convolutional.UpSampling2D( + size=(length_row, length_col), + data_format=data_format) + layer.build(inputs.shape) + outputs = layer(K.variable(inputs)) + np_output = K.eval(outputs) + if data_format == 'channels_first': + assert np_output.shape[2] == length_row * input_num_row + assert np_output.shape[3] == length_col * input_num_col + else: # tf + assert np_output.shape[1] == length_row * input_num_row + assert np_output.shape[2] == length_col * input_num_col @pytest.mark.skipif((K.backend() == 'cntk'),