Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions examples/shapes_3d_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
__author__ = 'Minhaz Palasara'

from keras.datasets import shapes_3d
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution3D, MaxPooling3D
from keras.optimizers import SGD, RMSprop
from keras.utils import np_utils, generic_utils
import theano


"""
To classify/track 3D shapes, such as human hands (http://www.dbs.ifi.lmu.de/~yu_k/icml2010_3dcnn.pdf),
we first need to find a distinct set of features. Specifically for 3D shapes, robust classification can be done using
3D features.

Features can be extracted by applying a 3D filters. We can auto learn these filters using 3D deep learning.

This example trains a simple network for classifying 3D shapes (Spheres, and Cubes).

GPU run command:
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python shapes_3d_cnn.py

CPU run command:
THEANO_FLAGS=mode=FAST_RUN,device=cpu,floatX=float32 python shapes_3d_cnn.py

For 4000 training samples and 1000 test samples.
90% accuracy reached in 10 epochs, 37 seconds/epoch on GTX Titan
"""

# Data Generation parameters
test_split = 0.2
dataset_size = 5000
patch_size = 32

(X_train, Y_train),(X_test, Y_test) = shapes_3d.load_data(test_split=test_split,
dataset_size=dataset_size,
patch_size=patch_size)

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# CNN Training parameters
batch_size = 128
nb_classes = 2
nb_epoch = 100

# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(Y_train, nb_classes)
Y_test = np_utils.to_categorical(Y_test, nb_classes)

# number of convolutional filters to use at each layer
nb_filters = [16, 32]

# level of pooling to perform at each layer (POOL x POOL)
nb_pool = [3, 3]

# level of convolution to perform at each layer (CONV x CONV)
nb_conv = [7, 3]

model = Sequential()
model.add(Convolution3D(nb_filters[0],nb_depth=nb_conv[0], nb_row=nb_conv[0], nb_col=nb_conv[0], border_mode='valid',
input_shape=(1, patch_size, patch_size, patch_size), activation='relu'))
model.add(MaxPooling3D(pool_size=(nb_pool[0], nb_pool[0], nb_pool[0])))
model.add(Dropout(0.5))
model.add(Convolution3D(nb_filters[1],nb_depth=nb_conv[1], nb_row=nb_conv[1], nb_col=nb_conv[1], border_mode='valid',
activation='relu'))
model.add(MaxPooling3D(pool_size=(nb_pool[1], nb_pool[1], nb_pool[1])))
model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(16, init='normal', activation='relu'))
model.add(Dense(nb_classes, init='normal'))
model.add(Activation('softmax'))

sgd = RMSprop(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd)

model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=2,
validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test, batch_size=batch_size, show_accuracy=True)
print('Test score:', score[0])
print('Test accuracy:', score[1])


104 changes: 104 additions & 0 deletions keras/datasets/shapes_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
__author__ = 'Jake Varley'

import numpy as np
import math
import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def load_data(test_split=0.2, dataset_size=5000, patch_size=32):
"""
Description
-----------
creates a dataset with total "dataset_size" samples.
Class of a sample (sphere, and cube) is chosen at random with equal probability.
Based on the "test_split", the dataset is divided in test and train subsets.
The "patch_size" defines the size of a 3D array for storing voxel.


Output shape
------------ size
(4D array, 1D array), (4D array, 1D array) ====> (Train_Voxels, Train_Lables), (Test_Voxels, Test_Labels)
Train and test split of total 'dataset_size' voxels with labels

Arguments
------------
test_split: float
percentage of total samples for training

dataset_size: int
total number of samples

patch_size:
size of each dimension of a 3D array to store voxel
"""

if patch_size < 10:
raise NotImplementedError

num_labels = 2

# Using same probability for each class
geometry_types = np.random.randint(0, num_labels, dataset_size)
random.shuffle(geometry_types)

# Getting the training set
y_train = geometry_types[0:abs((1-test_split)*dataset_size)]
x_train = __generate_solid_figures(geometry_types=y_train, patch_size=patch_size)

# Getting the testing set
y_test = geometry_types[abs((1-test_split)*dataset_size):]
x_test = __generate_solid_figures(geometry_types=y_test, patch_size=patch_size)

return (x_train, y_train),(x_test, y_test)

def __generate_solid_figures(geometry_types, patch_size):

"""
Output shape
------------
4D array (samples, patch_size(Z), patch_size(X), patch_size(Y))
Voxel for each label passed as input through geometry_types

Arguments
geometry_types: numpy array (samples, 1)
An array of class labels (0 for sphere, 1 for cube)
patch_size: int
Size of 3d array to store voxel

"""
shapes_no = geometry_types.shape[0]

# Assuming data is centered
(x0, y0, z0) = ((patch_size-1)/2,)*3

# Allocate 3D data array, data is in cube(all dimensions are same)
solid_figures = np.zeros((len(geometry_types), 1, patch_size,
patch_size, patch_size))
for i in range(0, len(geometry_types)):
# # radius is a random number in [3, self.patch_size/2)
radius = (patch_size/2 - 3) * np.random.rand() + 3

# bounding box values for optimization
x_min = int(max(math.ceil(x0-radius), 0))
y_min = int(max(math.ceil(y0-radius), 0))
z_min = int(max(math.ceil(z0-radius), 0))
x_max = int(min(math.floor(x0+radius), patch_size-1))
y_max = int(min(math.floor(y0+radius), patch_size-1))
z_max = int(min(math.floor(z0+radius), patch_size-1))

if geometry_types[i] == 0: #Sphere
radius_squared = radius**2
for z in xrange(z_min, z_max+1):
for x in xrange(x_min, x_max+1):
for y in xrange(y_min, y_max+1):
if (x-x0)**2 + (y-y0)**2 + (z-z0)**2 <= radius_squared:
# inside the sphere
solid_figures[i, 0, z, x, y] = 1
elif geometry_types[i] == 1: #Cube
solid_figures[i, 0, z_min:z_max+1, x_min:x_max+1, y_min:y_max+1] = 1
else:
raise NotImplementedError

return solid_figures
Loading