Conv2DTranspose supports dilation #11029
Thanks for the PR! Looks good to me at a high level.
padding='valid',
output_padding=None,
data_format=None,
dilation_rate=(1, 1),
This must be added to the docstring.
The docstring already exists at L671.
'data_format': 'channels_last',
'dilation_rate': (2, 2)},
input_shape=(num_samples, num_row, num_col, stack_size),
fixed_batch_size=True)
Why do you need this argument?
This is not necessary; I just followed the test case above. I will remove it.
@keras_test
@pytest.mark.skipif((K.backend() == 'cntk'),
                    reason='cntk only supports dilated conv transpose on GPU')
def test_conv2d_transpose_dilation():
Is this unit test really sufficient?
We can specify the input and a constant kernel and then check the output, but the Theano backend needs the kernel rotated, which complicates the test, and we would have to hard-code the output matrix. The following is the code I ran as an offline test; I will wrap it as a unit test.
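import numpy as np
from keras.initializers import Constant
from keras.layers import Conv2DTranspose, Input
from keras.models import Model
from keras.utils.conv_utils import convert_kernel
# Constant kernel: 3x3 spatial, 1 output filter, 3 input channels
kernel = np.arange(27).reshape((3, 3, 1, 3))
print(kernel.shape)
# kernel = convert_kernel(kernel) # Enable this line for Theano backend
inputs = Input(shape=(5, 6, 3))
m = Conv2DTranspose(1, kernel_size=3, padding='same', dilation_rate=2, kernel_initializer=Constant(kernel))(inputs)
model = Model(inputs=inputs, outputs=m)
model.summary()
# Deterministic input: batch of 2 samples, 5 rows, 6 columns, 3 channels
x = np.arange(180).reshape((2, 5, 6, 3)).astype(np.float32)
y = model.predict(x)
print(np.squeeze(y))
print(y.shape)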
'data_format': 'channels_last',
'dilation_rate': (2, 2),
'kernel_initializer': 'ones'},
expected_output=expected_output)
I added a test case here with the kernel initialized to ones, so no special handling is needed for the Theano backend, and the output is checked against the given input. The input data's batch size is only one, so the hard-coded output matrix stays small.
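A small NumPy sketch of why an all-ones kernel sidesteps the Theano issue, assuming the backend difference amounts to reversing the kernel's spatial axes (which is what the `convert_kernel` call in the offline script appears to compensate for). `flip_spatial` is a hypothetical helper written for this illustration, not a Keras function.

```python
import numpy as np


def flip_spatial(kernel):
    """Reverse the two spatial axes of a 4D kernel (hypothetical helper)."""
    return kernel[::-1, ::-1, :, :]


ones_kernel = np.ones((3, 3, 1, 3))
ramp_kernel = np.arange(27).reshape((3, 3, 1, 3))

# A kernel of ones is identical to its flipped version, so both
# conventions produce the same output; a ramp kernel is not.
ones_unchanged = np.array_equal(flip_spatial(ones_kernel), ones_kernel)
ramp_unchanged = np.array_equal(flip_spatial(ramp_kernel), ramp_kernel)
print(ones_unchanged, ramp_unchanged)  # True False
```

With a symmetric kernel the expected output can be hard-coded once and reused across backends.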
LGTM, thanks
* keras/master: (327 commits)
  Added in_train_phase and in_test_phase in the numpy backend. (keras-team#11061)
  Make sure the data_format argument defaults to ‘chanels_last’ for all 1D sequence layers.
  Speed up backend tests (keras-team#11051)
  Skipped some duplicated tests. (keras-team#11049)
  Used decorators and WITH_NP to avoid tests duplication. (keras-team#11050)
  Cached the theano compilation directory. (keras-team#11048)
  Removing duplicated backend tests. (keras-team#11037)
  [P, RELNOTES] Conv2DTranspose supports dilation (keras-team#11029)
  Doc Change: Change in shape for CIFAR Datasets (keras-team#11043)
  Fix line too long in mnist_acgan (keras-team#11040)
  Enable using last incomplete minibatch (keras-team#8344)
  Better UX (keras-team#11039)
  Update lstm text generation example (keras-team#11038)
  fix a bug, load_weights doesn't return anything (keras-team#11031)
  Speeding up the tests by reducing the number of K.eval(). (keras-team#11036)
  [P] Expose monitor value getter for easier subclass (keras-team#11002)
  [RELNOTES] Added the mode "bilinear" in the upscaling2D layer. (keras-team#10994)
  Separate pooling test from convolutional test and parameterize test case (keras-team#10975)
  Fix issue with non-canonical TF version name format.
  Allow TB callback to display float values.
  ...
Summary
I saw requests for dilated transposed convolution in several places (#8159, tensorflow/tensorflow#21598). This PR makes Conv2DTranspose support dilation for all backends.
How to test
In addition to the unit test I added, I also ran the following test. A dilated transposed convolution is equivalent to a standard transposed convolution with an upsampled kernel of effective size kernel_size + (kernel_size - 1) * (dilation_rate - 1), produced by inserting dilation_rate - 1 zeros between consecutive elements along the kernel's spatial dimensions. The following code verifies that these two scenarios give the same output:
Related Issues
#8159
tensorflow/tensorflow#21598
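The equivalence described under "How to test" can be illustrated with a minimal NumPy sketch. It is not the script from the PR and does not call Keras; it assumes a single channel, stride 1, and 'valid' padding, and both `dilate_kernel` and `conv2d_transpose_valid` are hypothetical helpers written for this illustration.

```python
import numpy as np


def dilate_kernel(kernel, rate):
    """Insert `rate - 1` zeros between neighbouring taps of a 2D kernel."""
    kh, kw = kernel.shape
    eff_h = kh + (kh - 1) * (rate - 1)  # effective kernel height
    eff_w = kw + (kw - 1) * (rate - 1)  # effective kernel width
    dilated = np.zeros((eff_h, eff_w), dtype=kernel.dtype)
    dilated[::rate, ::rate] = kernel
    return dilated


def conv2d_transpose_valid(x, kernel, rate=1):
    """Single-channel, stride-1, 'valid' transposed convolution.

    Each input pixel scatters a scaled copy of the kernel into the
    output; kernel taps are spaced `rate` pixels apart.
    """
    h, w = x.shape
    kh, kw = kernel.shape
    out = np.zeros((h + (kh - 1) * rate, w + (kw - 1) * rate))
    for i in range(h):
        for j in range(w):
            for a in range(kh):
                for b in range(kw):
                    out[i + a * rate, j + b * rate] += x[i, j] * kernel[a, b]
    return out


x = np.arange(30, dtype=np.float64).reshape(5, 6)
kernel = np.arange(9, dtype=np.float64).reshape(3, 3)
rate = 2

# Scenario 1: dilated transposed convolution with the original kernel.
dilated_out = conv2d_transpose_valid(x, kernel, rate=rate)
# Scenario 2: standard transposed convolution with the upsampled kernel.
upsampled_out = conv2d_transpose_valid(x, dilate_kernel(kernel, rate), rate=1)

print(np.allclose(dilated_out, upsampled_out))
```

For a 3x3 kernel and a dilation rate of 2, the upsampled kernel is 5x5, matching 3 + (3 - 1) * (2 - 1) = 5.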
PR Overview