Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Made the lines shorted in the resize implementation of tensorflow.
  • Loading branch information
gabrieldemarmiesse committed Aug 25, 2018
commit 8b6baf3683e325a6eba12ef7dee1099aca8bd148
61 changes: 34 additions & 27 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1962,7 +1962,11 @@ def permute_dimensions(x, pattern):
return tf.transpose(x, perm=pattern)


def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'):
def resize_images(x,
height_factor,
width_factor,
data_format,
interpolation='nearest'):
"""Resizes the images contained in a 4D tensor.

# Arguments
Expand All @@ -1978,37 +1982,40 @@ def resize_images(x, height_factor, width_factor, data_format, interpolation='ne
# Raises
ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`.
"""
if data_format == 'channels_first':
rows, cols = 2, 3
else:
rows, cols = 1, 2

original_shape = int_shape(x)
new_shape = tf.shape(x)[rows:cols + 1]
new_shape *= tf.constant(np.array([height_factor, width_factor], dtype='int32'))

if data_format == 'channels_first':
new_shape = tf.shape(x)[2:]
new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
x = permute_dimensions(x, [0, 2, 3, 1])
if interpolation == 'nearest':
x = tf.image.resize_nearest_neighbor(x, new_shape)
elif interpolation == 'bilinear':
x = tf.image.resize_bilinear(x, new_shape)
else:
raise ValueError('interpolation should be one '
'of "nearest" or "bilinear".')
if interpolation == 'nearest':
x = tf.image.resize_nearest_neighbor(x, new_shape)
elif interpolation == 'bilinear':
x = tf.image.resize_bilinear(x, new_shape)
else:
raise ValueError('interpolation should be one '
'of "nearest" or "bilinear".')
if data_format == 'channels_first':
x = permute_dimensions(x, [0, 3, 1, 2])
x.set_shape((None, None, original_shape[2] * height_factor if original_shape[2] is not None else None,
original_shape[3] * width_factor if original_shape[3] is not None else None))
return x
elif data_format == 'channels_last':
new_shape = tf.shape(x)[1:3]
new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
if interpolation == 'nearest':
x = tf.image.resize_nearest_neighbor(x, new_shape)
elif interpolation == 'bilinear':
x = tf.image.resize_bilinear(x, new_shape)
else:
raise ValueError('interpolation should be one '
'of "nearest" or "bilinear".')
x.set_shape((None, original_shape[1] * height_factor if original_shape[1] is not None else None,
original_shape[2] * width_factor if original_shape[2] is not None else None, None))
return x

if original_shape[rows] is None:
new_height = None
else:
raise ValueError('Unknown data_format: ' + str(data_format))
new_height = original_shape[rows] * height_factor

if original_shape[cols] is None:
new_width = None
else:
new_width = original_shape[cols] * width_factor

output_shape = (None, new_height, new_width, None)
x.set_shape(transpose_shape(output_shape, data_format, spatial_axes=(1, 2)))
return x


def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
Expand Down