Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add 3D Convolutional layers (Convolution3D, MaxPooling3D, AveragePooling3D, UpSampling3D and ZeroPadding3D, working with theano backend)

---------------------------------------
Squashed from the following commits
add Convolution3D and MaxPooling3D layers
fix 5D tensor in theano, add examples
update conv3d, pool3d, add resize_volumes and spatial_3d_padding
update Convolution3D, MaxPooling3D and AveragePooling3D, add UpSampling3D and ZeroPadding3D
add test functions for Convolution3D, MaxPooling3D, AveragePooling3D, ZeroPadding3D and UpSampling3D
small fix by changing pad_z to pad_t
update comment
skip some tests for tensorflow, @pytest.mark.skipif(K._BACKEND != theano, reason="Requires Theano backend")
use autopep8 to fix the code to match pep8 coding style
small fix (caused by autopep8)
small fix (caused by autopep8)
small fix (caused by autopep8)
fixed the document string for all newly added layers
remove the example and the dataset for 3d
add error message for tensorflow backend
support stride in pool3d
Rename "params" to "trainable_weights"
change notations and docstrings for 3D layers
fix pep8 error
change variable name in test code
small fix for pep8
add error message and docstring for strides in conv3d
fix test error caused by wrong strides in conv3d
support strides in conv3d by slicing the output
add if statement for stride (1,1,1)
fix get_config according to mdering, and other small fix
fix model_from_json issue by passing a 3d border_mode
fix according to jruales' review
change docstring in Convolution3D
delete docstring about TensorFlow
change docstring in Convolution3D and theano_backend
---------------------------------------

Author:    Wei OUYANG <[email protected]>
  • Loading branch information
oeway committed Feb 22, 2016
commit 3ecf201aeaf97c90fd8ba02a61c4a766b4b0cfb9
180 changes: 178 additions & 2 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from theano import tensor as T
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
from theano.tensor.signal import downsample
from theano.tensor.nnet import conv3d2d
import numpy as np
from .common import _FLOATX, _EPSILON

Expand Down Expand Up @@ -41,6 +42,7 @@ def placeholder(shape=None, ndim=None, dtype=_FLOATX, name=None):
raise Exception('Specify either a shape or ndim value.')
if shape is not None:
ndim = len(shape)

broadcast = (False,) * ndim
return T.TensorType(dtype, broadcast)(name)

Expand Down Expand Up @@ -265,6 +267,27 @@ def resize_images(X, height_factor, width_factor, dim_ordering):
raise Exception('Invalid dim_ordering: ' + dim_ordering)


def resize_volumes(X, depth_factor, height_factor, width_factor, dim_ordering):
    '''Upsample the volume held in a 5D tensor by element repetition.

    Expected layouts:
    - [batch, channels, depth, height, width] (for 'th' dim_ordering)
    - [batch, depth, height, width, channels] (for 'tf' dim_ordering)

    The depth, height and width axes are repeated by depth_factor,
    height_factor and width_factor respectively. All three factors
    should be positive integers.

    Raises an Exception when dim_ordering is neither 'th' nor 'tf'.
    '''
    # Index of the depth axis; height and width follow immediately after it.
    first_spatial_axis = {'th': 2, 'tf': 1}.get(dim_ordering)
    if first_spatial_axis is None:
        raise Exception('Invalid dim_ordering: ' + dim_ordering)

    resized = X
    factors = (depth_factor, height_factor, width_factor)
    for offset, factor in enumerate(factors):
        resized = repeat_elements(resized, factor,
                                  axis=first_spatial_axis + offset)
    return resized


def repeat(x, n):
'''Repeat a 2D tensor.

Expand Down Expand Up @@ -357,6 +380,42 @@ def spatial_2d_padding(x, padding=(1, 1), dim_ordering='th'):
raise Exception('Invalid dim_ordering: ' + dim_ordering)
return T.set_subtensor(output[indices], x)


def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering='th'):
    '''Zero-pad the three spatial dimensions of a 5D tensor.

    "padding[0]", "padding[1]" and "padding[2]" zeros are added on both
    sides of the first, second and third spatial axis (resp.). The spatial
    axes are 2, 3, 4 for 'th' dim_ordering and 1, 2, 3 for 'tf'.

    Raises an Exception when dim_ordering is neither 'th' nor 'tf'.
    '''
    in_shape = x.shape
    first_spatial_axis = {'th': 2, 'tf': 1}.get(dim_ordering)
    if first_spatial_axis is None:
        raise Exception('Invalid dim_ordering: ' + dim_ordering)

    # Start from "same size, take everything" and then widen / window
    # only the three spatial axes.
    padded_shape = [in_shape[axis] for axis in range(5)]
    window = [slice(None)] * 5
    for i in range(3):
        axis = first_spatial_axis + i
        pad = padding[i]
        padded_shape[axis] = in_shape[axis] + 2 * pad
        window[axis] = slice(pad, in_shape[axis] + pad)

    # Write the input into the centre of an all-zero canvas.
    canvas = T.zeros(tuple(padded_shape))
    return T.set_subtensor(canvas[tuple(window)], x)


# VALUE MANIPULATION


Expand Down Expand Up @@ -397,7 +456,6 @@ def gradients(loss, variables):
def rnn(step_function, inputs, initial_states,
go_backwards=False, mask=None, constants=None):
'''Iterates over the time dimension of a tensor.

Parameters
----------
inputs: tensor of temporal data of shape (samples, time, ...)
Expand All @@ -420,7 +478,6 @@ def rnn(step_function, inputs, initial_states,
mask: binary tensor with shape (samples, time),
with a zero for every element that is masked.
constants: a list of constant values passed at each step.

Returns
-------
A tuple (last_output, outputs, new_states).
Expand Down Expand Up @@ -647,6 +704,67 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid', dim_ordering='th',
return conv_out


def conv3d(x, kernel, strides=(1, 1, 1), border_mode='valid', dim_ordering='th',
           volume_shape=None, filter_shape=None):
    '''3D convolution implemented with Theano's conv3d2d.

    Run on cuDNN if available.
    NOTE(review): conv3d2d reportedly performs its work through
    nnet.conv2d, which picks cuDNN when it is available - confirm
    against the Theano version in use.

    # Arguments
        x: 5D input tensor.
            TH shape: (samples, input_depth, conv_dim1, conv_dim2, conv_dim3)
            TF shape: (samples, conv_dim1, conv_dim2, conv_dim3, input_depth)
        kernel: 5D kernel tensor.
            TH shape: (out_depth, input_depth,
                       kernel_dim1, kernel_dim2, kernel_dim3)
            TF shape: (kernel_dim1, kernel_dim2, kernel_dim3,
                       input_depth, out_depth)
        strides: sequence of 3 integers. Strides are emulated by slicing
            the stride-1 output, so they save no computation.
        border_mode: string, "same" or "valid".
        dim_ordering: string, "th" or "tf".
        volume_shape: accepted for interface parity with conv2d;
            currently unused.
        filter_shape: accepted for interface parity with conv2d;
            currently unused.

    # Returns
        The convolved 5D tensor, laid out according to dim_ordering.

    # Raises
        Exception: on an unknown dim_ordering, an invalid border_mode,
            or border_mode "same" combined with strides other
            than (1, 1, 1).
    '''
    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))

    if border_mode not in {'same', 'valid'}:
        raise Exception('Invalid border mode: ' + str(border_mode))

    # Normalise to a tuple so that list strides (e.g. restored from a
    # JSON config) compare equal to (1, 1, 1) below.
    strides = tuple(strides)

    # Explicit check instead of `assert`, which is stripped under -O.
    if border_mode == 'same' and strides != (1, 1, 1):
        raise Exception('border_mode="same" requires strides == (1, 1, 1), '
                        'got ' + str(strides))

    if dim_ordering == 'tf':
        # TF uses the last dimension as channel dimension,
        # instead of the 2nd one; bring both operands to the TH layout.
        x = x.dimshuffle((0, 4, 1, 2, 3))
        kernel = kernel.dimshuffle((4, 3, 0, 1, 2))

    if border_mode == 'same':
        # Emulate "same" by zero-padding the input by (kernel size - 1)
        # along each spatial axis and then running a "valid" convolution.
        pad_dim1 = (kernel.shape[2] - 1)
        pad_dim2 = (kernel.shape[3] - 1)
        pad_dim3 = (kernel.shape[4] - 1)
        output_shape = (x.shape[0], x.shape[1],
                        x.shape[2] + pad_dim1,
                        x.shape[3] + pad_dim2,
                        x.shape[4] + pad_dim3)
        output = T.zeros(output_shape)
        indices = (slice(None), slice(None),
                   slice(pad_dim1 // 2, x.shape[2] + pad_dim1 // 2),
                   slice(pad_dim2 // 2, x.shape[3] + pad_dim2 // 2),
                   slice(pad_dim3 // 2, x.shape[4] + pad_dim3 // 2))
        x = T.set_subtensor(output[indices], x)
        border_mode = 'valid'

    border_mode_3d = (border_mode, border_mode, border_mode)
    # Axes 1 and 2 are swapped on the way in and swapped back on the way
    # out to match the operand layout conv3d2d is called with here.
    conv_out = conv3d2d.conv3d(signals=x.dimshuffle(0, 2, 1, 3, 4),
                               filters=kernel.dimshuffle(0, 2, 1, 3, 4),
                               border_mode=border_mode_3d)
    conv_out = conv_out.dimshuffle(0, 2, 1, 3, 4)

    # support strides by manually slicing the output
    if strides != (1, 1, 1):
        conv_out = conv_out[:, :, ::strides[0], ::strides[1], ::strides[2]]

    if dim_ordering == 'tf':
        conv_out = conv_out.dimshuffle((0, 2, 3, 4, 1))

    return conv_out


def pool2d(x, pool_size, strides=(1, 1), border_mode='valid',
dim_ordering='th', pool_mode='max'):
if border_mode == 'same':
Expand Down Expand Up @@ -682,6 +800,64 @@ def pool2d(x, pool_size, strides=(1, 1), border_mode='valid',
return pool_out


def pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid',
           dim_ordering='th', pool_mode='max'):
    '''3D pooling assembled from two passes of Theano's 2D pooling.

    The first pass pools jointly over conv_dim1 and conv_dim2, the second
    pass pools over conv_dim3 alone.

    # Arguments
        x: 5D input tensor; for 'tf' dim_ordering the last axis is moved
            to position 1 before pooling and moved back afterwards.
        pool_size: sequence of 3 integers, window size per pooled axis.
        strides: sequence of 3 integers, step per pooled axis.
        border_mode: only "valid" is implemented.
        dim_ordering: string, "th" or "tf".
        pool_mode: string, "max" or "avg".

    # Returns
        The pooled 5D tensor, laid out according to dim_ordering.

    # Raises
        Exception: for border_mode "same" (not implemented), any other
            invalid border_mode, an unknown dim_ordering or an invalid
            pool_mode.
    '''
    if border_mode == 'same':
        # TODO: add implementation for border_mode="same"
        raise Exception('border_mode="same" not supported with Theano.')
    if border_mode != 'valid':
        raise Exception('Invalid border mode: ' + str(border_mode))
    ignore_border = True
    padding = (0, 0)

    if dim_ordering not in {'tf', 'th'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))

    channels_last = (dim_ordering == 'tf')
    if channels_last:
        x = x.dimshuffle((0, 4, 1, 2, 3))

    # Both pooling flavours share the same two-pass scheme and differ
    # only in the mode string handed to Theano.
    if pool_mode == 'max':
        theano_mode = 'max'
    elif pool_mode == 'avg':
        theano_mode = 'average_exc_pad'
    else:
        raise Exception('Invalid pooling mode: ' + str(pool_mode))

    # Pass 1: reverse the three trailing axes so that conv_dim2 and
    # conv_dim1 become the last two axes, which is where the 2D pooling
    # operates.
    first_pass = downsample.max_pool_2d(input=x.dimshuffle(0, 1, 4, 3, 2),
                                        ds=(pool_size[1], pool_size[0]),
                                        st=(strides[1], strides[0]),
                                        ignore_border=ignore_border,
                                        padding=padding,
                                        mode=theano_mode)

    # Pass 2: restore the axis order and pool over conv_dim3 only
    # (window and stride of 1 along the other trailing axis).
    pool_out = downsample.max_pool_2d(input=first_pass.dimshuffle(0, 1, 4, 3, 2),
                                      ds=(1, pool_size[2]),
                                      st=(1, strides[2]),
                                      ignore_border=ignore_border,
                                      padding=padding,
                                      mode=theano_mode)

    if channels_last:
        pool_out = pool_out.dimshuffle((0, 2, 3, 4, 1))
    return pool_out


# RANDOMNESS


Expand Down
Loading