diff --git a/examples/shapes_3d_cnn.py b/examples/shapes_3d_cnn.py new file mode 100644 index 000000000000..37decfec376e --- /dev/null +++ b/examples/shapes_3d_cnn.py @@ -0,0 +1,86 @@ +__author__ = 'Minhaz Palasara' + +from keras.datasets import shapes_3d +from keras.preprocessing.image import ImageDataGenerator +from keras.models import Sequential +from keras.layers.core import Dense, Dropout, Activation, Flatten +from keras.layers.convolutional import Convolution3D, MaxPooling3D +from keras.optimizers import SGD, RMSprop +from keras.utils import np_utils, generic_utils +import theano + + +""" + To classify/track 3D shapes, such as human hands (http://www.dbs.ifi.lmu.de/~yu_k/icml2010_3dcnn.pdf), + we first need to find a distinct set of features. Specifically for 3D shapes, robust classification can be done using + 3D features. + + Features can be extracted by applying a 3D filters. We can auto learn these filters using 3D deep learning. + + This example trains a simple network for classifying 3D shapes (Spheres, and Cubes). + + GPU run command: + THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python shapes_3d_cnn.py + + CPU run command: + THEANO_FLAGS=mode=FAST_RUN,device=cpu,floatX=float32 python shapes_3d_cnn.py + + For 4000 training samples and 1000 test samples. + 90% accuracy reached in 10 epochs, 37 seconds/epoch on GTX Titan +""" + +# Data Generation parameters +test_split = 0.2 +dataset_size = 5000 +patch_size = 32 + +(X_train, Y_train),(X_test, Y_test) = shapes_3d.load_data(test_split=test_split, + dataset_size=dataset_size, + patch_size=patch_size) + +print('X_train shape:', X_train.shape) +print(X_train.shape[0], 'train samples') +print(X_test.shape[0], 'test samples') + +# CNN Training parameters +batch_size = 128 +nb_classes = 2 +nb_epoch = 100 + +# convert class vectors to binary class matrices +Y_train = np_utils.to_categorical(Y_train, nb_classes) +Y_test = np_utils.to_categorical(Y_test, nb_classes) + +# number of convolutional filters to use at each layer +nb_filters = [16, 32] + +# level of pooling to perform at each layer (POOL x POOL) +nb_pool = [3, 3] + +# level of convolution to perform at each layer (CONV x CONV) +nb_conv = [7, 3] + +model = Sequential() +model.add(Convolution3D(nb_filters[0],nb_depth=nb_conv[0], nb_row=nb_conv[0], nb_col=nb_conv[0], border_mode='valid', + input_shape=(1, patch_size, patch_size, patch_size), activation='relu')) +model.add(MaxPooling3D(pool_size=(nb_pool[0], nb_pool[0], nb_pool[0]))) +model.add(Dropout(0.5)) +model.add(Convolution3D(nb_filters[1],nb_depth=nb_conv[1], nb_row=nb_conv[1], nb_col=nb_conv[1], border_mode='valid', + activation='relu')) +model.add(MaxPooling3D(pool_size=(nb_pool[1], nb_pool[1], nb_pool[1]))) +model.add(Flatten()) +model.add(Dropout(0.5)) +model.add(Dense(16, init='normal', activation='relu')) +model.add(Dense(nb_classes, init='normal')) +model.add(Activation('softmax')) + +sgd = RMSprop(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) +model.compile(loss='categorical_crossentropy', optimizer=sgd) + +model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=2, + validation_data=(X_test, Y_test)) +score = model.evaluate(X_test, Y_test, batch_size=batch_size, show_accuracy=True) +print('Test score:', score[0]) +print('Test accuracy:', score[1]) + + diff --git a/keras/datasets/shapes_3d.py b/keras/datasets/shapes_3d.py new file mode 100644 index 000000000000..b46621281f75 --- /dev/null +++ b/keras/datasets/shapes_3d.py @@ -0,0 +1,104 @@ +__author__ = 'Jake Varley' + +import numpy as np +import math +import random +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +import numpy as np + +def load_data(test_split=0.2, dataset_size=5000, patch_size=32): + """ + Description + ----------- + creates a dataset with total "dataset_size" samples. + Class of a sample (sphere, and cube) is chosen at random with equal probability. + Based on the "test_split", the dataset is divided in test and train subsets. + The "patch_size" defines the size of a 3D array for storing voxel. + + + Output shape + ------------ size + (4D array, 1D array), (4D array, 1D array) ====> (Train_Voxels, Train_Lables), (Test_Voxels, Test_Labels) + Train and test split of total 'dataset_size' voxels with labels + + Arguments + ------------ + test_split: float + percentage of total samples for training + + dataset_size: int + total number of samples + + patch_size: + size of each dimension of a 3D array to store voxel + """ + + if patch_size < 10: + raise NotImplementedError + + num_labels = 2 + + # Using same probability for each class + geometry_types = np.random.randint(0, num_labels, dataset_size) + random.shuffle(geometry_types) + + # Getting the training set + y_train = geometry_types[0:abs((1-test_split)*dataset_size)] + x_train = __generate_solid_figures(geometry_types=y_train, patch_size=patch_size) + + # Getting the testing set + y_test = geometry_types[abs((1-test_split)*dataset_size):] + x_test = __generate_solid_figures(geometry_types=y_test, patch_size=patch_size) + + return (x_train, y_train),(x_test, y_test) + +def __generate_solid_figures(geometry_types, patch_size): + + """ + Output shape + ------------ + 4D array (samples, patch_size(Z), patch_size(X), patch_size(Y)) + Voxel for each label passed as input through geometry_types + + Arguments + geometry_types: numpy array (samples, 1) + An array of class labels (0 for sphere, 1 for cube) + patch_size: int + Size of 3d array to store voxel + + """ + shapes_no = geometry_types.shape[0] + + # Assuming data is centered + (x0, y0, z0) = ((patch_size-1)/2,)*3 + + # Allocate 3D data array, data is in cube(all dimensions are same) + solid_figures = np.zeros((len(geometry_types), 1, patch_size, + patch_size, patch_size)) + for i in range(0, len(geometry_types)): + # # radius is a random number in [3, self.patch_size/2) + radius = (patch_size/2 - 3) * np.random.rand() + 3 + + # bounding box values for optimization + x_min = int(max(math.ceil(x0-radius), 0)) + y_min = int(max(math.ceil(y0-radius), 0)) + z_min = int(max(math.ceil(z0-radius), 0)) + x_max = int(min(math.floor(x0+radius), patch_size-1)) + y_max = int(min(math.floor(y0+radius), patch_size-1)) + z_max = int(min(math.floor(z0+radius), patch_size-1)) + + if geometry_types[i] == 0: #Sphere + radius_squared = radius**2 + for z in xrange(z_min, z_max+1): + for x in xrange(x_min, x_max+1): + for y in xrange(y_min, y_max+1): + if (x-x0)**2 + (y-y0)**2 + (z-z0)**2 <= radius_squared: + # inside the sphere + solid_figures[i, 0, z, x, y] = 1 + elif geometry_types[i] == 1: #Cube + solid_figures[i, 0, z_min:z_max+1, x_min:x_max+1, y_min:y_max+1] = 1 + else: + raise NotImplementedError + + return solid_figures diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py index bb86aa7bbd01..95023a8197df 100644 --- a/keras/layers/convolutional.py +++ b/keras/layers/convolutional.py @@ -4,7 +4,7 @@ import theano import theano.tensor as T from theano.tensor.signal import downsample - +from theano.tensor.nnet import conv3d2d from .. import activations, initializations, regularizers, constraints from ..utils.theano_utils import shared_zeros, on_gpu from ..layers.core import Layer @@ -285,6 +285,140 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) +class Convolution3D(Layer): + input_ndim = 5 + + def __init__(self, nb_filter, nb_depth, nb_row, nb_col, + init='glorot_uniform', activation='linear', weights=None, + border_mode='valid', subsample=(1, 1, 1), + W_regularizer=None, b_regularizer=None, activity_regularizer=None, + W_constraint=None, b_constraint=None, **kwargs): + + if border_mode not in {'valid', 'full', 'same'}: + raise Exception('Invalid border mode for Convolution3D:', border_mode) + self.nb_filter = nb_filter + self.nb_depth = nb_depth + self.nb_row = nb_row + self.nb_col = nb_col + self.init = initializations.get(init) + self.activation = activations.get(activation) + self.border_mode = border_mode + self.subsample = tuple(subsample) + + self.W_regularizer = regularizers.get(W_regularizer) + self.b_regularizer = regularizers.get(b_regularizer) + self.activity_regularizer = regularizers.get(activity_regularizer) + + self.W_constraint = constraints.get(W_constraint) + self.b_constraint = constraints.get(b_constraint) + self.constraints = [self.W_constraint, self.b_constraint] + + self.initial_weights = weights + super(Convolution3D, self).__init__(**kwargs) + + def build(self): + stack_size = self.input_shape[1] + dtensor5 = T.TensorType('float32', (0,)*5) + self.input = dtensor5() + self.W_shape = (self.nb_filter, stack_size, self.nb_depth, self.nb_row, self.nb_col) + self.W = self.init(self.W_shape) + self.b = shared_zeros((self.nb_filter,)) + self.params = [self.W, self.b] + self.regularizers = [] + + if self.W_regularizer: + self.W_regularizer.set_param(self.W) + self.regularizers.append(self.W_regularizer) + + if self.b_regularizer: + self.b_regularizer.set_param(self.b) + self.regularizers.append(self.b_regularizer) + + if self.activity_regularizer: + self.activity_regularizer.set_layer(self) + self.regularizers.append(self.activity_regularizer) + + if self.initial_weights is not None: + self.set_weights(self.initial_weights) + del self.initial_weights + + @property + def output_shape(self): + input_shape = self.input_shape + depth = input_shape[2] + rows = input_shape[3] + cols = input_shape[4] + depth = conv_output_length(depth, self.nb_depth, self.border_mode, self.subsample[0]) + rows = conv_output_length(rows, self.nb_row, self.border_mode, self.subsample[1]) + cols = conv_output_length(cols, self.nb_col, self.border_mode, self.subsample[2]) + return (input_shape[0], self.nb_filter, depth, rows, cols) + + def get_output(self, train): + X = self.get_input(train) + border_mode = self.border_mode + + # Both conv3d2d.conv3d and nnet.conv3D only support the 'valid' border mode + if border_mode != 'valid': + if border_mode == 'same': + assert(self.subsample == (1, 1, 1)) + pad_z = (self.nb_depth - self.subsample[0]) + pad_x = (self.nb_row - self.subsample[1]) + pad_y = (self.nb_col - self.subsample[2]) + else: #full + pad_z = (self.nb_depth - 1) * 2 + pad_x = (self.nb_row - 1) * 2 + pad_y = (self.nb_col - 1) * 2 + + input_shape = X.shape + output_shape = (input_shape[0], input_shape[1], + input_shape[2] + pad_z, + input_shape[3] + pad_x, + input_shape[4] + pad_y) + output = T.zeros(output_shape) + indices = (slice(None), slice(None), + slice(pad_z//2, input_shape[2] + pad_z//2), + slice(pad_x//2, input_shape[3] + pad_x//2), + slice(pad_y//2, input_shape[4] + pad_y//2)) + X = T.set_subtensor(output[indices], X) + + + border_mode = 'valid' + + if on_gpu(): + # Shuffle the dimensions as per the input parameter order, restore it once done + conv_out = conv3d2d.conv3d(signals=X.dimshuffle(0, 2, 1, 3, 4), + filters=self.W.dimshuffle(0, 2, 1, 3, 4), + border_mode=border_mode) + + conv_out = conv_out.dimshuffle(0, 2, 1, 3, 4) + else: + # Shuffle the dimensions as per the input parameter order, restore it once done + conv_out = T.nnet.conv3D(V=X.dimshuffle(0, 2, 3, 4, 1), + W=self.W.dimshuffle(0, 2, 3, 4, 1), + b=self.b, d=self.subsample) + conv_out = conv_out.dimshuffle(0, 4, 1, 2, 3) + + output = self.activation(conv_out + self.b.dimshuffle('x', 0, 'x', 'x', 'x')) + return output + + def get_config(self): + return {"name": self.__class__.__name__, + "nb_filter": self.nb_filter, + "stack_size": self.stack_size, + "nb_depth": self.nb_depth, + "nb_row": self.nb_row, + "nb_col": self.nb_col, + "init": self.init.__name__, + "activation": self.activation.__name__, + "border_mode": self.border_mode, + "subsample": self.subsample, + "W_regularizer": self.W_regularizer.get_config() if self.W_regularizer else None, + "b_regularizer": self.b_regularizer.get_config() if self.b_regularizer else None, + "activity_regularizer": self.activity_regularizer.get_config() if self.activity_regularizer else None, + "W_constraint": self.W_constraint.get_config() if self.W_constraint else None, + "b_constraint": self.b_constraint.get_config() if self.b_constraint else None} + + class MaxPooling1D(Layer): input_ndim = 3 @@ -355,6 +489,54 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) +class MaxPooling3D(Layer): + input_ndim = 5 + + def __init__(self, pool_size=(2, 2, 2), stride=None, ignore_border=True, **kwargs): + super(MaxPooling3D, self).__init__(**kwargs) + self.mode = 'max' + self.pool_size = tuple(pool_size) + self.ignore_border = ignore_border + if stride is None: + stride = self.pool_size + self.stride = tuple(stride) + self.ignore_border = ignore_border + + dtensor5 = T.TensorType('float32', (0,)*5) + self.input = dtensor5() + self.params = [] + + @property + def output_shape(self): + input_shape = self.input_shape + depth = pool_output_length(input_shape[2], self.pool_size[0], self.ignore_border, self.stride[0]) + rows = pool_output_length(input_shape[3], self.pool_size[1], self.ignore_border, self.stride[1]) + cols = pool_output_length(input_shape[4], self.pool_size[2], self.ignore_border, self.stride[2]) + return (input_shape[0], input_shape[1], depth, rows, cols) + + def get_output(self, train): + X = self.get_input(train) + + # pooling over X, Z (last two channels) + output = downsample.max_pool_2d(input=X.dimshuffle(0, 1, 4, 3, 2), + ds=(self.pool_size[1], self.pool_size[0]), + ignore_border=self.ignore_border) + + # max_pool_2d X and Y, X constant + output = downsample.max_pool_2d(input=output.dimshuffle(0, 1, 4, 3, 2), + ds=(1, self.pool_size[2]), + ignore_border=self.ignore_border) + return output + + def get_config(self): + return {"name": self.__class__.__name__, + "pool_size": self.pool_size, + "ignore_border": self.ignore_border, + "stride": self.stride} + base_config = super(MaxPooling3D, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + class UpSample1D(Layer): input_ndim = 3 @@ -502,3 +684,60 @@ def get_config(self): "padding": self.padding} base_config = super(ZeroPadding2D, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + +class ZeroPadding3D(Layer): + """Zero-padding layer for 3D input (e.g. 3D voxel points of hand). + + Input shape + ----------- + 5D tensor with shape (samples, channels, depth, first_axis_to_pad, second_axis_to_pad) + + Output shape + ------------ + 5D tensor with shape (samples, channels, depth, first_padded_axis, second_padded_axis) + + Arguments + --------- + padding: tuple of int (length 3) + How many zeros to add at the beginning and end of + the 3 padding dimensions (axis 3 and 4 and 5). + """ + input_ndim = 5 + + def __init__(self, padding=(1, 1, 1), **kwargs): + super(ZeroPadding3D, self).__init__(**kwargs) + self.padding = tuple(padding) + dtensor5 = T.TensorType('float32', (0,)*5) + self.input = dtensor5() + + @property + def output_shape(self): + input_shape = self.input_shape + return (input_shape[0], + input_shape[1], + input_shape[2] + 2 * self.padding[0], + input_shape[3] + 2 * self.padding[1], + input_shape[4] + 2 * self.padding[2]) + + def get_output(self, train=False): + X = self.get_input(train) + input_shape = X.shape + output_shape = (input_shape[0], + input_shape[1], + input_shape[2] + 2 * self.padding[0], + input_shape[3] + 2 * self.padding[1], + input_shape[4] + 2 * self.padding[2]) + output = T.zeros(output_shape) + indices = (slice(None), + slice(None), + slice(self.padding[0], input_shape[2] + self.padding[0]), + slice(self.padding[1], input_shape[3] + self.padding[1]), + slice(self.padding[2], input_shape[4] + self.padding[2])) + return T.set_subtensor(output[indices], X) + + def get_config(self): + config = {"name": self.__class__.__name__, + "padding": self.padding} + base_config = super(ZeroPadding3D, self).get_config() + return dict(list(base_config.items()) + list(config.items())) \ No newline at end of file