diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 5ea9a2f16ab5..94cd2a387481 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -2,6 +2,7 @@ from theano import tensor as T from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams from theano.tensor.signal import downsample +from theano.tensor.nnet import conv3d2d import numpy as np from .common import _FLOATX, _EPSILON @@ -41,6 +42,7 @@ def placeholder(shape=None, ndim=None, dtype=_FLOATX, name=None): raise Exception('Specify either a shape or ndim value.') if shape is not None: ndim = len(shape) + broadcast = (False,) * ndim return T.TensorType(dtype, broadcast)(name) @@ -265,6 +267,27 @@ def resize_images(X, height_factor, width_factor, dim_ordering): raise Exception('Invalid dim_ordering: ' + dim_ordering) +def resize_volumes(X, depth_factor, height_factor, width_factor, dim_ordering): + '''Resize the volume contained in a 5D tensor of shape + - [batch, channels, depth, height, width] (for 'th' dim_ordering) + - [batch, depth, height, width, channels] (for 'tf' dim_ordering) + by a factor of (depth_factor, height_factor, width_factor). + Both factors should be positive integers. + ''' + if dim_ordering == 'th': + output = repeat_elements(X, depth_factor, axis=2) + output = repeat_elements(output, height_factor, axis=3) + output = repeat_elements(output, width_factor, axis=4) + return output + elif dim_ordering == 'tf': + output = repeat_elements(X, depth_factor, axis=1) + output = repeat_elements(output, height_factor, axis=2) + output = repeat_elements(output, width_factor, axis=3) + return output + else: + raise Exception('Invalid dim_ordering: ' + dim_ordering) + + def repeat(x, n): '''Repeat a 2D tensor. @@ -357,6 +380,42 @@ def spatial_2d_padding(x, padding=(1, 1), dim_ordering='th'): raise Exception('Invalid dim_ordering: ' + dim_ordering) return T.set_subtensor(output[indices], x) + +def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering='th'): + '''Pad the 2nd, 3rd and 4th dimensions of a 5D tensor + with "padding[0]", "padding[1]" and "padding[2]" (resp.) zeros left and right. + ''' + input_shape = x.shape + if dim_ordering == 'th': + output_shape = (input_shape[0], + input_shape[1], + input_shape[2] + 2 * padding[0], + input_shape[3] + 2 * padding[1], + input_shape[4] + 2 * padding[2]) + output = T.zeros(output_shape) + indices = (slice(None), + slice(None), + slice(padding[0], input_shape[2] + padding[0]), + slice(padding[1], input_shape[3] + padding[1]), + slice(padding[2], input_shape[4] + padding[2])) + + elif dim_ordering == 'tf': + output_shape = (input_shape[0], + input_shape[1] + 2 * padding[0], + input_shape[2] + 2 * padding[1], + input_shape[3] + 2 * padding[2], + input_shape[4]) + output = T.zeros(output_shape) + indices = (slice(None), + slice(padding[0], input_shape[1] + padding[0]), + slice(padding[1], input_shape[2] + padding[1]), + slice(padding[2], input_shape[3] + padding[2]), + slice(None)) + else: + raise Exception('Invalid dim_ordering: ' + dim_ordering) + return T.set_subtensor(output[indices], x) + + # VALUE MANIPULATION @@ -397,7 +456,6 @@ def gradients(loss, variables): def rnn(step_function, inputs, initial_states, go_backwards=False, mask=None, constants=None): '''Iterates over the time dimension of a tensor. - Parameters ---------- inputs: tensor of temporal data of shape (samples, time, ...) @@ -420,7 +478,6 @@ def rnn(step_function, inputs, initial_states, mask: binary tensor with shape (samples, time), with a zero for every element that is masked. constants: a list of constant values passed at each step. - Returns ------- A tuple (last_output, outputs, new_states). @@ -647,6 +704,67 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid', dim_ordering='th', return conv_out +def conv3d(x, kernel, strides=(1, 1, 1), border_mode='valid', dim_ordering='th', + volume_shape=None, filter_shape=None): + ''' + Run on cuDNN if available. + border_mode: string, "same" or "valid". + ''' + if dim_ordering not in {'th', 'tf'}: + raise Exception('Unknown dim_ordering ' + str(dim_ordering)) + + if border_mode not in {'same', 'valid'}: + raise Exception('Invalid border mode: ' + str(border_mode)) + + if dim_ordering == 'tf': + # TF uses the last dimension as channel dimension, + # instead of the 2nd one. + # TH input shape: (samples, input_depth, conv_dim1, conv_dim2, conv_dim3) + # TF input shape: (samples, conv_dim1, conv_dim2, conv_dim3, input_depth) + # TH kernel shape: (out_depth, input_depth, kernel_dim1, kernel_dim2, kernel_dim3) + # TF kernel shape: (kernel_dim1, kernel_dim2, kernel_dim3, input_depth, out_depth) + x = x.dimshuffle((0, 4, 1, 2, 3)) + kernel = kernel.dimshuffle((4, 3, 0, 1, 2)) + if volume_shape: + volume_shape = (volume_shape[0], volume_shape[4], + volume_shape[1], volume_shape[2], volume_shape[3]) + if filter_shape: + filter_shape = (filter_shape[4], filter_shape[3], + filter_shape[0], filter_shape[1], filter_shape[2]) + + if border_mode == 'same': + assert(strides == (1, 1, 1)) + pad_dim1 = (kernel.shape[2] - 1) + pad_dim2 = (kernel.shape[3] - 1) + pad_dim3 = (kernel.shape[4] - 1) + output_shape = (x.shape[0], x.shape[1], + x.shape[2] + pad_dim1, + x.shape[3] + pad_dim2, + x.shape[4] + pad_dim3) + output = T.zeros(output_shape) + indices = (slice(None), slice(None), + slice(pad_dim1 // 2, x.shape[2] + pad_dim1 // 2), + slice(pad_dim2 // 2, x.shape[3] + pad_dim2 // 2), + slice(pad_dim3 // 2, x.shape[4] + pad_dim3 // 2)) + x = T.set_subtensor(output[indices], x) + border_mode = 'valid' + + border_mode_3d = (border_mode, border_mode, border_mode) + conv_out = conv3d2d.conv3d(signals=x.dimshuffle(0, 2, 1, 3, 4), + filters=kernel.dimshuffle(0, 2, 1, 3, 4), + border_mode=border_mode_3d) + conv_out = conv_out.dimshuffle(0, 2, 1, 3, 4) + + # support strides by manually slicing the output + if strides != (1, 1, 1): + conv_out = conv_out[:, :, ::strides[0], ::strides[1], ::strides[2]] + + if dim_ordering == 'tf': + conv_out = conv_out.dimshuffle((0, 2, 3, 4, 1)) + + return conv_out + + def pool2d(x, pool_size, strides=(1, 1), border_mode='valid', dim_ordering='th', pool_mode='max'): if border_mode == 'same': @@ -682,6 +800,64 @@ def pool2d(x, pool_size, strides=(1, 1), border_mode='valid', return pool_out +def pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid', + dim_ordering='th', pool_mode='max'): + if border_mode == 'same': + # TODO: add implementation for border_mode="same" + raise Exception('border_mode="same" not supported with Theano.') + elif border_mode == 'valid': + ignore_border = True + padding = (0, 0) + else: + raise Exception('Invalid border mode: ' + str(border_mode)) + + if dim_ordering not in {'th', 'tf'}: + raise Exception('Unknown dim_ordering ' + str(dim_ordering)) + + if dim_ordering == 'tf': + x = x.dimshuffle((0, 4, 1, 2, 3)) + + if pool_mode == 'max': + # pooling over conv_dim2, conv_dim1 (last two channels) + output = downsample.max_pool_2d(input=x.dimshuffle(0, 1, 4, 3, 2), + ds=(pool_size[1], pool_size[0]), + st=(strides[1], strides[0]), + ignore_border=ignore_border, + padding=padding, + mode='max') + + # pooling over conv_dim3 + pool_out = downsample.max_pool_2d(input=output.dimshuffle(0, 1, 4, 3, 2), + ds=(1, pool_size[2]), + st=(1, strides[2]), + ignore_border=ignore_border, + padding=padding, + mode='max') + + elif pool_mode == 'avg': + # pooling over conv_dim2, conv_dim1 (last two channels) + output = downsample.max_pool_2d(input=x.dimshuffle(0, 1, 4, 3, 2), + ds=(pool_size[1], pool_size[0]), + st=(strides[1], strides[0]), + ignore_border=ignore_border, + padding=padding, + mode='average_exc_pad') + + # pooling over conv_dim3 + pool_out = downsample.max_pool_2d(input=output.dimshuffle(0, 1, 4, 3, 2), + ds=(1, pool_size[2]), + st=(1, strides[2]), + ignore_border=ignore_border, + padding=padding, + mode='average_exc_pad') + else: + raise Exception('Invalid pooling mode: ' + str(pool_mode)) + + if dim_ordering == 'tf': + pool_out = pool_out.dimshuffle((0, 2, 3, 4, 1)) + return pool_out + + # RANDOMNESS diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py index 2ab14a7080ed..08ad068def49 100644 --- a/keras/layers/convolutional.py +++ b/keras/layers/convolutional.py @@ -343,6 +343,196 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) +class Convolution3D(Layer): + '''Convolution operator for filtering windows of three-dimensional inputs. + When using this layer as the first layer in a model, + provide the keyword argument `input_shape` + (tuple of integers, does not include the sample axis), + e.g. `input_shape=(3, 10, 128, 128)` for 10 frames of 128x128 RGB pictures. + + Note: this layer will only work with Theano for the time being. + + # Input shape + 5D tensor with shape: + `(samples, channels, conv_dim1, conv_dim2, conv_dim3)` if dim_ordering='th' + or 5D tensor with shape: + `(samples, conv_dim1, conv_dim2, conv_dim3, channels)` if dim_ordering='tf'. + + # Output shape + 5D tensor with shape: + `(samples, nb_filter, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if dim_ordering='th' + or 5D tensor with shape: + `(samples, new_conv_dim1, new_conv_dim2, new_conv_dim3, nb_filter)` if dim_ordering='tf'. + `new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have changed due to padding. + + # Arguments + nb_filter: Number of convolution filters to use. + kernel_dim1: Length of the first dimension in the covolution kernel. + kernel_dim2: Length of the second dimension in the convolution kernel. + kernel_dim3: Length of the third dimension in the convolution kernel. + init: name of initialization function for the weights of the layer + (see [initializations](../initializations.md)), or alternatively, + Theano function to use for weights initialization. + This parameter is only relevant if you don't pass + a `weights` argument. + activation: name of activation function to use + (see [activations](../activations.md)), + or alternatively, elementwise Theano function. + If you don't specify anything, no activation is applied + (ie. "linear" activation: a(x) = x). + weights: list of numpy arrays to set as initial weights. + border_mode: 'valid' or 'same'. + subsample: tuple of length 3. Factor by which to subsample output. + Also called strides elsewhere. + Note: 'subsample' is implemented by slicing the output of conv3d with strides=(1,1,1). + W_regularizer: instance of [WeightRegularizer](../regularizers.md) + (eg. L1 or L2 regularization), applied to the main weights matrix. + b_regularizer: instance of [WeightRegularizer](../regularizers.md), + applied to the bias. + activity_regularizer: instance of [ActivityRegularizer](../regularizers.md), + applied to the network output. + W_constraint: instance of the [constraints](../constraints.md) module + (eg. maxnorm, nonneg), applied to the main weights matrix. + b_constraint: instance of the [constraints](../constraints.md) module, + applied to the bias. + dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension + (the depth) is at index 1, in 'tf' mode is it at index 4. + ''' + input_ndim = 5 + + def __init__(self, nb_filter, kernel_dim1, kernel_dim2, kernel_dim3, + init='glorot_uniform', activation='linear', weights=None, + border_mode='valid', subsample=(1, 1, 1), dim_ordering='th', + W_regularizer=None, b_regularizer=None, activity_regularizer=None, + W_constraint=None, b_constraint=None, **kwargs): + if K._BACKEND != 'theano': + raise Exception(self.__class__.__name__ + + ' is currently only working with Theano backend.') + if border_mode not in {'valid', 'same'}: + raise Exception('Invalid border mode for Convolution3D:', border_mode) + self.nb_filter = nb_filter + self.kernel_dim1 = kernel_dim1 + self.kernel_dim2 = kernel_dim2 + self.kernel_dim3 = kernel_dim3 + self.init = initializations.get(init) + self.activation = activations.get(activation) + assert border_mode in {'valid', 'same'}, 'border_mode must be in {valid, same}' + self.border_mode = border_mode + self.subsample = tuple(subsample) + assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}' + self.dim_ordering = dim_ordering + + self.W_regularizer = regularizers.get(W_regularizer) + self.b_regularizer = regularizers.get(b_regularizer) + self.activity_regularizer = regularizers.get(activity_regularizer) + + self.W_constraint = constraints.get(W_constraint) + self.b_constraint = constraints.get(b_constraint) + self.constraints = [self.W_constraint, self.b_constraint] + + self.initial_weights = weights + self.input = K.placeholder(ndim=5) + super(Convolution3D, self).__init__(**kwargs) + + def build(self): + + if self.dim_ordering == 'th': + stack_size = self.input_shape[1] + self.W_shape = (self.nb_filter, stack_size, + self.kernel_dim1, self.kernel_dim2, self.kernel_dim3) + elif self.dim_ordering == 'tf': + stack_size = self.input_shape[4] + self.W_shape = (self.kernel_dim1, self.kernel_dim2, self.kernel_dim3, + stack_size, self.nb_filter) + else: + raise Exception('Invalid dim_ordering: ' + self.dim_ordering) + + self.W = self.init(self.W_shape) + self.b = K.zeros((self.nb_filter,)) + self.trainable_weights = [self.W, self.b] + self.regularizers = [] + + if self.W_regularizer: + self.W_regularizer.set_param(self.W) + self.regularizers.append(self.W_regularizer) + + if self.b_regularizer: + self.b_regularizer.set_param(self.b) + self.regularizers.append(self.b_regularizer) + + if self.activity_regularizer: + self.activity_regularizer.set_layer(self) + self.regularizers.append(self.activity_regularizer) + + if self.initial_weights is not None: + self.set_weights(self.initial_weights) + del self.initial_weights + + @property + def output_shape(self): + input_shape = self.input_shape + if self.dim_ordering == 'th': + conv_dim1 = input_shape[2] + conv_dim2 = input_shape[3] + conv_dim3 = input_shape[4] + elif self.dim_ordering == 'tf': + conv_dim1 = input_shape[1] + conv_dim2 = input_shape[2] + conv_dim3 = input_shape[3] + else: + raise Exception('Invalid dim_ordering: ' + self.dim_ordering) + + conv_dim1 = conv_output_length(conv_dim1, self.kernel_dim1, + self.border_mode, self.subsample[0]) + conv_dim2 = conv_output_length(conv_dim2, self.kernel_dim2, + self.border_mode, self.subsample[1]) + conv_dim3 = conv_output_length(conv_dim3, self.kernel_dim3, + self.border_mode, self.subsample[2]) + + if self.dim_ordering == 'th': + return (input_shape[0], self.nb_filter, conv_dim1, conv_dim2, conv_dim3) + elif self.dim_ordering == 'tf': + return (input_shape[0], conv_dim1, conv_dim2, conv_dim3, self.nb_filter) + else: + raise Exception('Invalid dim_ordering: ' + self.dim_ordering) + + def get_output(self, train=False): + X = self.get_input(train) + conv_out = K.conv3d(X, self.W, strides=self.subsample, + border_mode=self.border_mode, + dim_ordering=self.dim_ordering, + volume_shape=self.input_shape, + filter_shape=self.W_shape) + + if self.dim_ordering == 'th': + output = conv_out + K.reshape(self.b, (1, self.nb_filter, 1, 1, 1)) + elif self.dim_ordering == 'tf': + output = conv_out + K.reshape(self.b, (1, 1, 1, 1, self.nb_filter)) + else: + raise Exception('Invalid dim_ordering: ' + self.dim_ordering) + output = self.activation(output) + return output + + def get_config(self): + config = {"name": self.__class__.__name__, + "nb_filter": self.nb_filter, + "kernel_dim1": self.kernel_dim1, + "kernel_dim2": self.kernel_dim2, + "kernel_dim3": self.kernel_dim3, + "dim_ordering": self.dim_ordering, + "init": self.init.__name__, + "activation": self.activation.__name__, + "border_mode": self.border_mode, + "subsample": self.subsample, + "W_regularizer": self.W_regularizer.get_config() if self.W_regularizer else None, + "b_regularizer": self.b_regularizer.get_config() if self.b_regularizer else None, + "activity_regularizer": self.activity_regularizer.get_config() if self.activity_regularizer else None, + "W_constraint": self.W_constraint.get_config() if self.W_constraint else None, + "b_constraint": self.b_constraint.get_config() if self.b_constraint else None} + base_config = super(Convolution3D, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + class _Pooling1D(Layer): '''Abstract class for different pooling 1D layers. ''' @@ -407,6 +597,7 @@ class MaxPooling1D(_Pooling1D): border_mode: 'valid' or 'same'. Note: 'same' will only work with TensorFlow for the time being. ''' + def __init__(self, pool_length=2, stride=None, border_mode='valid', **kwargs): super(MaxPooling1D, self).__init__(pool_length, stride, @@ -434,6 +625,7 @@ class AveragePooling1D(_Pooling1D): border_mode: 'valid' or 'same'. Note: 'same' will only work with TensorFlow for the time being. ''' + def __init__(self, pool_length=2, stride=None, border_mode='valid', **kwargs): super(AveragePooling1D, self).__init__(pool_length, stride, @@ -535,6 +727,7 @@ class MaxPooling2D(_Pooling2D): dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension (the depth) is at index 1, in 'tf' mode is it at index 3. ''' + def __init__(self, pool_size=(2, 2), strides=None, border_mode='valid', dim_ordering='th', **kwargs): super(MaxPooling2D, self).__init__(pool_size, strides, border_mode, @@ -572,6 +765,7 @@ class AveragePooling2D(_Pooling2D): dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension (the depth) is at index 1, in 'tf' mode is it at index 3. ''' + def __init__(self, pool_size=(2, 2), strides=None, border_mode='valid', dim_ordering='th', **kwargs): super(AveragePooling2D, self).__init__(pool_size, strides, border_mode, @@ -584,6 +778,158 @@ def _pooling_function(self, inputs, pool_size, strides, return output +class _Pooling3D(Layer): + '''Abstract class for different pooling 3D layers. + ''' + input_ndim = 5 + + def __init__(self, pool_size=(2, 2, 2), strides=None, border_mode='valid', + dim_ordering='th', **kwargs): + super(_Pooling3D, self).__init__(**kwargs) + self.input = K.placeholder(ndim=5) + self.pool_size = tuple(pool_size) + if strides is None: + strides = self.pool_size + self.strides = tuple(strides) + assert border_mode in {'valid', 'same'}, 'border_mode must be in {valid, same}' + self.border_mode = border_mode + assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}' + self.dim_ordering = dim_ordering + + @property + def output_shape(self): + input_shape = self.input_shape + if self.dim_ordering == 'th': + len_dim1 = input_shape[2] + len_dim2 = input_shape[3] + len_dim3 = input_shape[4] + elif self.dim_ordering == 'tf': + len_dim1 = input_shape[1] + len_dim2 = input_shape[2] + len_dim3 = input_shape[3] + else: + raise Exception('Invalid dim_ordering: ' + self.dim_ordering) + + len_dim1 = conv_output_length(len_dim1, self.pool_size[0], + self.border_mode, self.strides[0]) + len_dim2 = conv_output_length(len_dim2, self.pool_size[1], + self.border_mode, self.strides[1]) + len_dim3 = conv_output_length(len_dim3, self.pool_size[2], + self.border_mode, self.strides[2]) + + if self.dim_ordering == 'th': + return (input_shape[0], input_shape[1], len_dim1, len_dim2, len_dim3) + elif self.dim_ordering == 'tf': + return (input_shape[0], len_dim1, len_dim2, len_dim3, input_shape[4]) + else: + raise Exception('Invalid dim_ordering: ' + self.dim_ordering) + + def _pooling_function(self, inputs, pool_size, strides, + border_mode, dim_ordering): + raise NotImplementedError + + def get_output(self, train=False): + X = self.get_input(train) + output = self._pooling_function(inputs=X, pool_size=self.pool_size, + strides=self.strides, + border_mode=self.border_mode, + dim_ordering=self.dim_ordering) + return output + + def get_config(self): + config = {'name': self.__class__.__name__, + 'pool_size': self.pool_size, + 'border_mode': self.border_mode, + 'strides': self.strides, + 'dim_ordering': self.dim_ordering} + base_config = super(_Pooling3D, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class MaxPooling3D(_Pooling3D): + '''Max pooling operation for 3D data (spatial or spatio-temporal). + + Note: this layer will only work with Theano for the time being. + + # Input shape + 5D tensor with shape: + `(samples, channels, len_pool_dim1, len_pool_dim2, len_pool_dim3)` if dim_ordering='th' + or 5D tensor with shape: + `(samples, len_pool_dim1, len_pool_dim2, len_pool_dim3, channels)` if dim_ordering='tf'. + + # Output shape + 5D tensor with shape: + `(nb_samples, channels, pooled_dim1, pooled_dim2, pooled_dim3)` if dim_ordering='th' + or 5D tensor with shape: + `(samples, pooled_dim1, pooled_dim2, pooled_dim3, channels)` if dim_ordering='tf'. + + # Arguments + pool_size: tuple of 3 integers, + factors by which to downscale (dim1, dim2, dim3). + (2, 2, 2) will halve the size of the 3D input in each dimension. + strides: tuple of 3 integers, or None. Strides values. + border_mode: 'valid' or 'same'. + dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension + (the depth) is at index 1, in 'tf' mode is it at index 4. + ''' + + def __init__(self, pool_size=(2, 2, 2), strides=None, border_mode='valid', + dim_ordering='th', **kwargs): + if K._BACKEND != 'theano': + raise Exception(self.__class__.__name__ + + ' is currently only working with Theano backend.') + super(MaxPooling3D, self).__init__(pool_size, strides, border_mode, + dim_ordering, **kwargs) + + def _pooling_function(self, inputs, pool_size, strides, + border_mode, dim_ordering): + output = K.pool3d(inputs, pool_size, strides, + border_mode, dim_ordering, pool_mode='max') + return output + + +class AveragePooling3D(_Pooling3D): + '''Average pooling operation for 3D data (spatial or spatio-temporal). + + Note: this layer will only work with Theano for the time being. + + # Input shape + 5D tensor with shape: + `(samples, channels, len_pool_dim1, len_pool_dim2, len_pool_dim3)` if dim_ordering='th' + or 5D tensor with shape: + `(samples, len_pool_dim1, len_pool_dim2, len_pool_dim3, channels)` if dim_ordering='tf'. + + # Output shape + 5D tensor with shape: + `(nb_samples, channels, pooled_dim1, pooled_dim2, pooled_dim3)` if dim_ordering='th' + or 5D tensor with shape: + `(samples, pooled_dim1, pooled_dim2, pooled_dim3, channels)` if dim_ordering='tf'. + + # Arguments + pool_size: tuple of 3 integers, + factors by which to downscale (dim1, dim2, dim3). + (2, 2, 2) will halve the size of the 3D input in each dimension. + strides: tuple of 3 integers, or None. Strides values. + border_mode: 'valid' or 'same'. + dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension + (the depth) is at index 1, in 'tf' mode is it at index 4. + ''' + + def __init__(self, pool_size=(2, 2, 2), strides=None, border_mode='valid', + dim_ordering='th', **kwargs): + if K._BACKEND != 'theano': + raise Exception(self.__class__.__name__ + + ' is currently only working with Theano backend.') + super(AveragePooling3D, self).__init__(pool_size, strides, border_mode, + dim_ordering, **kwargs) + + def _pooling_function(self, inputs, pool_size, strides, + border_mode, dim_ordering): + output = K.pool3d(inputs, pool_size, strides, + border_mode, dim_ordering, pool_mode='avg') + return output + + class UpSampling1D(Layer): '''Repeat each temporal step `length` times along the time axis. @@ -679,6 +1025,72 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) +class UpSampling3D(Layer): + '''Repeat the first, second and third dimension of the data + by size[0], size[1] and size[2] respectively. + + Note: this layer will only work with Theano for the time being. + + # Input shape + 5D tensor with shape: + `(samples, channels, dim1, dim2, dim3)` if dim_ordering='th' + or 5D tensor with shape: + `(samples, dim1, dim2, dim3, channels)` if dim_ordering='tf'. + + # Output shape + 5D tensor with shape: + `(samples, channels, upsampled_dim1, upsampled_dim2, upsampled_dim3)` if dim_ordering='th' + or 5D tensor with shape: + `(samples, upsampled_dim1, upsampled_dim2, upsampled_dim3, channels)` if dim_ordering='tf'. + + # Arguments + size: tuple of 3 integers. The upsampling factors for dim1, dim2 and dim3. + dim_ordering: 'th' or 'tf'. + In 'th' mode, the channels dimension (the depth) + is at index 1, in 'tf' mode is it at index 4. + ''' + input_ndim = 5 + + def __init__(self, size=(2, 2, 2), dim_ordering='th', **kwargs): + if K._BACKEND != 'theano': + raise Exception(self.__class__.__name__ + + ' is currently only working with Theano backend.') + super(UpSampling3D, self).__init__(**kwargs) + self.input = K.placeholder(ndim=5) + self.size = tuple(size) + assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}' + self.dim_ordering = dim_ordering + + @property + def output_shape(self): + input_shape = self.input_shape + if self.dim_ordering == 'th': + return (input_shape[0], + input_shape[1], + self.size[0] * input_shape[2], + self.size[1] * input_shape[3], + self.size[2] * input_shape[4]) + elif self.dim_ordering == 'tf': + return (input_shape[0], + self.size[0] * input_shape[1], + self.size[1] * input_shape[2], + self.size[2] * input_shape[3], + input_shape[4]) + else: + raise Exception('Invalid dim_ordering: ' + self.dim_ordering) + + def get_output(self, train=False): + X = self.get_input(train) + return K.resize_volumes(X, self.size[0], self.size[1], self.size[2], + self.dim_ordering) + + def get_config(self): + config = {'name': self.__class__.__name__, + 'size': self.size} + base_config = super(UpSampling3D, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + class ZeroPadding1D(Layer): '''Zero-padding layer for 1D input (e.g. temporal sequence). @@ -769,3 +1181,63 @@ def get_config(self): 'padding': self.padding} base_config = super(ZeroPadding2D, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + +class ZeroPadding3D(Layer): + '''Zero-padding layer for 3D data (spatial or spatio-temporal). + + Note: this layer will only work with Theano for the time being. + + # Input shape + 5D tensor with shape: + (samples, depth, first_axis_to_pad, second_axis_to_pad, third_axis_to_pad) + + # Output shape + 5D tensor with shape: + (samples, depth, first_padded_axis, second_padded_axis, third_axis_to_pad) + + # Arguments + padding: tuple of int (length 3) + How many zeros to add at the beginning and end of + the 3 padding dimensions (axis 3, 4 and 5). + ''' + input_ndim = 5 + + def __init__(self, padding=(1, 1, 1), dim_ordering='th', **kwargs): + if K._BACKEND != 'theano': + raise Exception(self.__class__.__name__ + + ' is currently only working with Theano backend.') + super(ZeroPadding3D, self).__init__(**kwargs) + self.padding = tuple(padding) + self.input = K.placeholder(ndim=5) + assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}' + self.dim_ordering = dim_ordering + + @property + def output_shape(self): + input_shape = self.input_shape + if self.dim_ordering == 'th': + return (input_shape[0], + input_shape[1], + input_shape[2] + 2 * self.padding[0], + input_shape[3] + 2 * self.padding[1], + input_shape[4] + 2 * self.padding[2]) + elif self.dim_ordering == 'tf': + return (input_shape[0], + input_shape[1] + 2 * self.padding[0], + input_shape[2] + 2 * self.padding[1], + input_shape[3] + 2 * self.padding[2], + input_shape[4]) + else: + raise Exception('Invalid dim_ordering: ' + self.dim_ordering) + + def get_output(self, train=False): + X = self.get_input(train) + return K.spatial_3d_padding(X, padding=self.padding, + dim_ordering=self.dim_ordering) + + def get_config(self): + config = {'name': self.__class__.__name__, + 'padding': self.padding} + base_config = super(ZeroPadding3D, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/tests/keras/layers/test_convolutional.py b/tests/keras/layers/test_convolutional.py index fc67f6901337..cc32cc4935eb 100644 --- a/tests/keras/layers/test_convolutional.py +++ b/tests/keras/layers/test_convolutional.py @@ -204,6 +204,92 @@ def test_averagepooling_2d(): layer.get_config() +@pytest.mark.skipif(K._BACKEND != 'theano', reason="Requires Theano backend") +def test_convolution_3d(): + nb_samples = 8 + nb_filter = 9 + stack_size = 7 + len_conv_dim1 = 2 + len_conv_dim2 = 10 + len_conv_dim3 = 6 + + input_len_dim1 = 10 + input_len_dim2 = 11 + input_len_dim3 = 12 + + weights_in = [np.ones((nb_filter, stack_size, len_conv_dim1, len_conv_dim2, len_conv_dim3)), + np.ones(nb_filter)] + + input = np.ones((nb_samples, stack_size, input_len_dim1, + input_len_dim2, input_len_dim3)) + for weight in [None, weights_in]: + for border_mode in ['same', 'valid']: + for subsample in [(1, 1, 1), (2, 2, 2)]: + if border_mode == 'same' and subsample != (1, 1, 1): + continue + for W_regularizer in [None, 'l2']: + for b_regularizer in [None, 'l2']: + for act_regularizer in [None, 'l2']: + layer = convolutional.Convolution3D( + nb_filter, len_conv_dim1, len_conv_dim2, len_conv_dim3, + weights=weight, + border_mode=border_mode, + W_regularizer=W_regularizer, + b_regularizer=b_regularizer, + activity_regularizer=act_regularizer, + subsample=subsample, + input_shape=(stack_size, None, None, None)) + + layer.input = K.variable(input) + for train in [True, False]: + out = K.eval(layer.get_output(train)) + if border_mode == 'same' and subsample == (1, 1, 1): + assert out.shape[2:] == input.shape[2:] + layer.get_config() + + +@pytest.mark.skipif(K._BACKEND != 'theano', reason="Requires Theano backend") +def test_maxpooling_3d(): + nb_samples = 9 + stack_size = 7 + input_len_dim1 = 10 + input_len_dim2 = 11 + input_len_dim3 = 12 + pool_size = (3, 3, 3) + + input = np.ones((nb_samples, stack_size, input_len_dim1, + input_len_dim2, input_len_dim3)) + for strides in [(1, 1, 1), (2, 2, 2)]: + layer = convolutional.MaxPooling3D(strides=strides, + border_mode='valid', + pool_size=pool_size) + layer.input = K.variable(input) + for train in [True, False]: + K.eval(layer.get_output(train)) + layer.get_config() + + +@pytest.mark.skipif(K._BACKEND != 'theano', reason="Requires Theano backend") +def test_averagepooling_3d(): + nb_samples = 9 + stack_size = 7 + input_len_dim1 = 10 + input_len_dim2 = 11 + input_len_dim3 = 12 + pool_size = (3, 3, 3) + + input = np.ones((nb_samples, stack_size, input_len_dim1, + input_len_dim2, input_len_dim3)) + for strides in [(1, 1, 1), (2, 2, 2)]: + layer = convolutional.AveragePooling3D(strides=strides, + border_mode='valid', + pool_size=pool_size) + layer.input = K.variable(input) + for train in [True, False]: + K.eval(layer.get_output(train)) + layer.get_config() + + def test_zero_padding_2d(): nb_samples = 9 stack_size = 7 @@ -222,6 +308,28 @@ def test_zero_padding_2d(): layer.get_config() +@pytest.mark.skipif(K._BACKEND != 'theano', reason="Requires Theano backend") +def test_zero_padding_3d(): + nb_samples = 9 + stack_size = 7 + input_len_dim1 = 10 + input_len_dim2 = 11 + input_len_dim3 = 12 + + input = np.ones((nb_samples, stack_size, input_len_dim1, + input_len_dim2, input_len_dim3)) + layer = convolutional.ZeroPadding3D(padding=(2, 2, 2)) + layer.input = K.variable(input) + for train in [True, False]: + out = K.eval(layer.get_output(train)) + for offset in [0, 1, -1, -2]: + assert_allclose(out[:, :, offset, :, :], 0.) + assert_allclose(out[:, :, :, offset, :], 0.) + assert_allclose(out[:, :, :, :, offset], 0.) + assert_allclose(out[:, :, 2:-2, 2:-2, 2:-2], 1.) + layer.get_config() + + def test_upsampling_1d(): nb_samples = 9 nb_steps = 7 @@ -253,29 +361,76 @@ def test_upsampling_2d(): for length_row in [2, 3, 9]: for length_col in [2, 3, 9]: - layer = convolutional.UpSampling2D( - size=(length_row, length_col), + layer = convolutional.UpSampling2D( + size=(length_row, length_col), + input_shape=input.shape[1:], + dim_ordering=dim_ordering) + layer.input = K.variable(input) + for train in [True, False]: + out = K.eval(layer.get_output(train)) + if dim_ordering == 'th': + assert out.shape[2] == length_row * input_nb_row + assert out.shape[3] == length_col * input_nb_col + else: # tf + assert out.shape[1] == length_row * input_nb_row + assert out.shape[2] == length_col * input_nb_col + + # compare with numpy + if dim_ordering == 'th': + expected_out = np.repeat(input, length_row, axis=2) + expected_out = np.repeat(expected_out, length_col, axis=3) + else: # tf + expected_out = np.repeat(input, length_row, axis=1) + expected_out = np.repeat(expected_out, length_col, axis=2) + + assert_allclose(out, expected_out) + + layer.get_config() + + +@pytest.mark.skipif(K._BACKEND != 'theano', reason="Requires Theano backend") +def test_upsampling_3d(): + nb_samples = 9 + stack_size = 7 + input_len_dim1 = 10 + input_len_dim2 = 11 + input_len_dim3 = 12 + + for dim_ordering in ['th', 'tf']: + if dim_ordering == 'th': + input = np.random.rand(nb_samples, stack_size, input_len_dim1, input_len_dim2, + input_len_dim3) + else: # tf + input = np.random.rand(nb_samples, input_len_dim1, input_len_dim2, input_len_dim3, + stack_size) + for length_dim1 in [2, 3, 9]: + for length_dim2 in [2, 3, 9]: + for length_dim3 in [2, 3, 9]: + layer = convolutional.UpSampling3D( + size=(length_dim1, length_dim2, length_dim3), input_shape=input.shape[1:], dim_ordering=dim_ordering) layer.input = K.variable(input) for train in [True, False]: out = K.eval(layer.get_output(train)) if dim_ordering == 'th': - assert out.shape[2] == length_row * input_nb_row - assert out.shape[3] == length_col * input_nb_col + assert out.shape[2] == length_dim1 * input_len_dim1 + assert out.shape[3] == length_dim2 * input_len_dim2 + assert out.shape[4] == length_dim3 * input_len_dim3 else: # tf - assert out.shape[1] == length_row * input_nb_row - assert out.shape[2] == length_col * input_nb_col + assert out.shape[1] == length_dim1 * input_len_dim1 + assert out.shape[2] == length_dim2 * input_len_dim2 + assert out.shape[3] == length_dim3 * input_len_dim3 # compare with numpy if dim_ordering == 'th': - expected_out = np.repeat(input, length_row, axis=2) - expected_out = np.repeat(expected_out, length_col, - axis=3) + expected_out = np.repeat(input, length_dim1, axis=2) + expected_out = np.repeat(expected_out, length_dim2, axis=3) + expected_out = np.repeat(expected_out, length_dim3, axis=4) else: # tf - expected_out = np.repeat(input, length_row, axis=1) - expected_out = np.repeat(expected_out, length_col, - axis=2) + expected_out = np.repeat(input, length_dim1, axis=1) + expected_out = np.repeat(expected_out, length_dim2, axis=2) + expected_out = np.repeat(expected_out, length_dim3, axis=3) assert_allclose(out, expected_out)