Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,17 +1156,20 @@ def permute_dimensions(x, pattern):
return C.transpose(x, axis)


def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'):
    """Upscale `x` by repeating its elements along the two spatial axes.

    # Arguments
        x: Tensor or variable to resize.
        height_factor: Repetition count passed to `repeat_elements`
            for the height axis.
        width_factor: Repetition count passed to `repeat_elements`
            for the width axis.
        data_format: `"channels_first"` (spatial axes 2 and 3) or
            `"channels_last"` (spatial axes 1 and 2).
        interpolation: Only `"nearest"` is accepted by this backend.

    # Returns
        The result of repeating `x` along the height axis and then
        along the width axis.

    # Raises
        NotImplementedError: if `interpolation` is not `"nearest"`.
        ValueError: if `interpolation` is `"nearest"` and `data_format`
            is neither `"channels_first"` nor `"channels_last"`.
    """
    # The interpolation check comes first, so an unsupported mode is
    # reported even when `data_format` is invalid as well.
    if interpolation != 'nearest':
        raise NotImplementedError('CNTK only supports `nearest` interpolation.')

    if data_format == 'channels_first':
        height_axis, width_axis = 2, 3
    elif data_format == 'channels_last':
        height_axis, width_axis = 1, 2
    else:
        raise ValueError('CNTK Backend: Invalid data_format: %s' % data_format)

    stretched = repeat_elements(x, height_factor, axis=height_axis)
    return repeat_elements(stretched, width_factor, axis=width_axis)


def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
Expand All @@ -1181,7 +1184,7 @@ def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
output = repeat_elements(output, width_factor, axis=3)
return output
else:
raise ValueError('CNTK Backend: Invalid data_format:', data_format)
raise ValueError('CNTK Backend: Invalid data_format: %s' % data_format)


def repeat_elements(x, rep, axis):
Expand Down
51 changes: 35 additions & 16 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1962,14 +1962,19 @@ def permute_dimensions(x, pattern):
return tf.transpose(x, perm=pattern)


def resize_images(x,
                  height_factor,
                  width_factor,
                  data_format,
                  interpolation='nearest'):
    """Resizes the images contained in a 4D tensor.

    # Arguments
        x: Tensor or variable to resize.
        height_factor: Positive integer.
        width_factor: Positive integer.
        data_format: string, `"channels_last"` or `"channels_first"`.
        interpolation: A string, one of `nearest` or `bilinear`.

    # Returns
        A tensor.

    # Raises
        ValueError: if `data_format` is neither
            `"channels_last"` nor `"channels_first"`, or if
            `interpolation` is neither `"nearest"` nor `"bilinear"`.
    """
    # Validate both string arguments before any op is created, so that an
    # invalid call fails immediately and never falls through to the
    # channels_last code path by accident.
    if data_format == 'channels_first':
        rows, cols = 2, 3
    elif data_format == 'channels_last':
        rows, cols = 1, 2
    else:
        raise ValueError('Unknown data_format: ' + str(data_format))

    if interpolation not in ('nearest', 'bilinear'):
        raise ValueError('interpolation should be one '
                         'of "nearest" or "bilinear".')

    # `int_shape(x)` entries may be None (checked below); the target size
    # handed to the resize op is therefore computed from `tf.shape(x)`.
    original_shape = int_shape(x)
    new_shape = tf.shape(x)[rows:cols + 1]
    new_shape *= tf.constant(np.array([height_factor, width_factor], dtype='int32'))

    # The resize ops are applied on a channels_last layout, so a
    # channels_first input is transposed first and restored afterwards.
    if data_format == 'channels_first':
        x = permute_dimensions(x, [0, 2, 3, 1])
    if interpolation == 'nearest':
        x = tf.image.resize_nearest_neighbor(x, new_shape)
    else:
        x = tf.image.resize_bilinear(x, new_shape)
    if data_format == 'channels_first':
        x = permute_dimensions(x, [0, 3, 1, 2])

    # Scale the known spatial dimensions; unknown ones stay None.
    if original_shape[rows] is None:
        new_height = None
    else:
        new_height = original_shape[rows] * height_factor

    if original_shape[cols] is None:
        new_width = None
    else:
        new_width = original_shape[cols] * width_factor

    # NOTE(review): `transpose_shape` is presumably what reorders this
    # channels_last-style tuple to match `data_format` -- confirm.
    output_shape = (None, new_height, new_width, None)
    x.set_shape(transpose_shape(output_shape, data_format, spatial_axes=(1, 2)))
    return x


def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
Expand Down
42 changes: 35 additions & 7 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,24 +965,52 @@ def repeat_elements(x, rep, axis):
return y


def resize_images(x,
                  height_factor,
                  width_factor,
                  data_format,
                  interpolation='nearest'):
    """Resize the images contained in a 4D tensor of shape
    - [batch, channels, height, width] (for 'channels_first' data_format)
    - [batch, height, width, channels] (for 'channels_last' data_format)
    by a factor of (height_factor, width_factor). Both factors should be
    positive integers.

    # Arguments
        x: Tensor or variable to resize.
        height_factor: Positive integer.
        width_factor: Positive integer.
        data_format: string, `"channels_last"` or `"channels_first"`.
        interpolation: A string, one of `nearest` or `bilinear`.

    # Returns
        The resized tensor.

    # Raises
        ValueError: if `data_format` or `interpolation` is not one of
            the accepted strings.
        NotImplementedError: if `interpolation` is `"bilinear"` and the
            factors are not both equal to 2.
    """
    if data_format == 'channels_first':
        axis_1 = 2
        axis_2 = 3
    elif data_format == 'channels_last':
        axis_1 = 1
        axis_2 = 2
    else:
        # Build a single message string; passing two arguments to
        # ValueError renders the error as a tuple.
        raise ValueError('Invalid data_format: ' + str(data_format))

    if interpolation == 'nearest':
        output = repeat_elements(x, height_factor, axis=axis_1)
        output = repeat_elements(output, width_factor, axis=axis_2)
    elif interpolation == 'bilinear':
        if not (height_factor == width_factor == 2):
            raise NotImplementedError(
                'Bilinear upscaling with factors other than (2, 2) '
                'is not available when using the Theano backend.')
        # The upsampling op is applied on a channels_first layout, so a
        # channels_last input is transposed first and restored afterwards.
        if data_format == 'channels_last':
            output = permute_dimensions(x, [0, 3, 1, 2])
        else:
            output = x
        output = T.nnet.abstract_conv.bilinear_upsampling(output,
                                                          ratio=height_factor)
        if data_format == 'channels_last':
            output = permute_dimensions(output, [0, 2, 3, 1])
        if hasattr(x, '_keras_shape'):
            keras_shape = list(x._keras_shape)
            # NOTE(review): assumes an unknown dimension is stored as None
            # (as in the TensorFlow backend's shape handling) -- confirm.
            # Such an entry is left as None instead of being multiplied.
            if keras_shape[axis_1] is not None:
                keras_shape[axis_1] *= height_factor
            if keras_shape[axis_2] is not None:
                keras_shape[axis_2] *= width_factor
            output._keras_shape = tuple(keras_shape)
    else:
        raise ValueError('interpolation should be one of "nearest" or "bilinear".')

    return output


def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
"""Resize the volume contained in a 5D tensor of shape
Expand Down
11 changes: 9 additions & 2 deletions keras/layers/convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1916,6 +1916,9 @@ class UpSampling2D(Layer):
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
interpolation: A string, one of `nearest` or `bilinear`.
Note that CNTK does not yet support `bilinear` upscaling,
and that with Theano, only `size=(2, 2)` is supported for `bilinear`.

# Input shape
4D tensor with shape:
Expand All @@ -1933,11 +1936,15 @@ class UpSampling2D(Layer):
"""

@interfaces.legacy_upsampling2d_support
def __init__(self, size=(2, 2), data_format=None, interpolation='nearest', **kwargs):
    """Store the layer configuration.

    # Arguments
        size: Upsampling factors; stored after being passed through
            `conv_utils.normalize_tuple(size, 2, 'size')`.
        data_format: Stored after being passed through
            `K.normalize_data_format`.
        interpolation: A string, one of `nearest` or `bilinear`.
        **kwargs: Forwarded to the parent layer constructor.

    # Raises
        ValueError: if `interpolation` is neither `"nearest"`
            nor `"bilinear"`.
    """
    super(UpSampling2D, self).__init__(**kwargs)
    self.data_format = K.normalize_data_format(data_format)
    self.size = conv_utils.normalize_tuple(size, 2, 'size')
    self.input_spec = InputSpec(ndim=4)

    # Reject unknown modes at construction time rather than when the
    # layer is called.
    supported_modes = ['nearest', 'bilinear']
    if interpolation not in supported_modes:
        raise ValueError('interpolation should be one of "nearest" or "bilinear".')
    self.interpolation = interpolation

def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
Expand All @@ -1957,7 +1964,7 @@ def compute_output_shape(self, input_shape):

def call(self, inputs):
    """Resize `inputs` with the backend using this layer's settings.

    Forwards the two entries of `self.size` as the height and width
    factors, together with `self.data_format` and `self.interpolation`,
    to `K.resize_images` and returns its result.
    """
    height_factor = self.size[0]
    width_factor = self.size[1]
    return K.resize_images(inputs,
                           height_factor,
                           width_factor,
                           self.data_format,
                           self.interpolation)

def get_config(self):
config = {'size': self.size,
Expand Down
17 changes: 17 additions & 0 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,23 @@ def test_resize_images(self):
K.resize_images(K.variable(xval), 2, 2,
data_format='channels_middle')

@staticmethod
def _helper_bilinear(data_format, height_factor, width_factor):
    """Run the bilinear `resize_images` check for one configuration.

    Calls `check_single_tensor_operation` for `resize_images` on an
    input of shape (2, 3, 4, 5) with the backend list `[KTF, KTH]`
    (presumably the TensorFlow and Theano backends -- confirm against
    the module's imports).
    """
    resize_kwargs = {
        'height_factor': height_factor,
        'width_factor': width_factor,
        'data_format': data_format,
        'interpolation': 'bilinear',
    }
    check_single_tensor_operation('resize_images', (2, 3, 4, 5),
                                  [KTF, KTH],
                                  **resize_kwargs)

@pytest.mark.skipif(K.backend() == 'cntk', reason='Not supported.')
@pytest.mark.parametrize('data_format', ['channels_first', 'channels_last'])
def test_resize_images_bilinear(self, data_format):
    """Bilinear resizing works for (2, 2) and is rejected for (4, 4)."""
    # Factors of (2, 2) are expected to pass the helper's check.
    self._helper_bilinear(data_format, height_factor=2, width_factor=2)

    # NOTE(review): the NotImplementedError expected here presumably comes
    # from a backend that only supports (2, 2) bilinear upscaling --
    # confirm which backends the helper actually exercises.
    with pytest.raises(NotImplementedError):
        self._helper_bilinear(data_format, height_factor=4, width_factor=4)

def test_resize_volumes(self):
for data_format in ['channels_first', 'channels_last']:
shape = (5, 5, 5)
Expand Down
40 changes: 40 additions & 0 deletions tests/keras/layers/convolutional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,46 @@ def test_upsampling_2d():
assert_allclose(np_output, expected_out)


# Parameterized over data_format, as requested in review (see #10975).
@pytest.mark.skipif((K.backend() == 'cntk'),
                    reason='cntk does not support it yet')
@pytest.mark.parametrize('data_format',
                         ['channels_first', 'channels_last'])
def test_upsampling_2d_bilinear(data_format):
    """Check UpSampling2D in bilinear mode and its output spatial shape."""
    num_samples = 2
    stack_size = 2
    input_num_row = 11
    input_num_col = 12

    # Pick the input layout and the positions of the spatial axes together.
    if data_format == 'channels_first':
        row_axis, col_axis = 2, 3
        input_shape = (num_samples, stack_size, input_num_row, input_num_col)
    else:  # channels_last
        row_axis, col_axis = 1, 2
        input_shape = (num_samples, input_num_row, input_num_col, stack_size)
    inputs = np.random.rand(*input_shape)

    # basic test
    bilinear_kwargs = {'size': (2, 2),
                       'data_format': data_format,
                       'interpolation': 'bilinear'}
    layer_test(convolutional.UpSampling2D,
               kwargs=bilinear_kwargs,
               input_shape=inputs.shape)

    # NOTE(review): the layers built below do not pass
    # `interpolation='bilinear'`, so these shape checks run with the
    # default 'nearest' mode -- confirm that this is intended.
    length_row = 2
    for length_col in (2, 3):
        layer = convolutional.UpSampling2D(
            size=(length_row, length_col),
            data_format=data_format)
        layer.build(inputs.shape)
        np_output = K.eval(layer(K.variable(inputs)))
        assert np_output.shape[row_axis] == length_row * input_num_row
        assert np_output.shape[col_axis] == length_col * input_num_col


@pytest.mark.skipif((K.backend() == 'cntk'),
reason="cntk does not support it yet")
def test_upsampling_3d():
Expand Down