Commit 4291077

Remove duplicate code and make it support the new dataset loading function.
1 parent 8a9a83d

1 file changed (+6, -34 lines)

code/logistic_cg.py

Lines changed: 6 additions & 34 deletions
@@ -48,6 +48,8 @@
 import theano
 import theano.tensor as T
 
+from logistic_sgd import load_data
+
 
 class LogisticRegression(object):
     """Multi-class Logistic Regression Class
@@ -148,41 +150,11 @@ def cg_optimization_mnist(n_epochs=50, mnist_pkl_gz='mnist.pkl.gz'):
     #############
     # LOAD DATA #
     #############
-    print '... loading data'
-
-    # Load the dataset
-    f = gzip.open(mnist_pkl_gz, 'rb')
-    train_set, valid_set, test_set = cPickle.load(f)
-    f.close()
+    datasets = load_data(mnist_pkl_gz)
 
-    def shared_dataset(data_xy, borrow=True):
-        """ Function that loads the dataset into shared variables
-
-        The reason we store our dataset in shared variables is to allow
-        Theano to copy it into the GPU memory (when code is run on GPU).
-        Since copying data into the GPU is slow, copying a minibatch every time
-        it is needed (the default behaviour if the data is not in a shared
-        variable) would lead to a large decrease in performance.
-        """
-        data_x, data_y = data_xy
-        shared_x = theano.shared(numpy.asarray(data_x,
-                                               dtype=theano.config.floatX),
-                                 borrow=borrow)
-        shared_y = theano.shared(numpy.asarray(data_y,
-                                               dtype=theano.config.floatX),
-                                 borrow=borrow)
-        # When storing data on the GPU it has to be stored as floats,
-        # therefore we will store the labels as ``floatX`` as well
-        # (``shared_y`` does exactly that). But during our computations
-        # we need them as ints (we use labels as indices, and if they are
-        # floats it doesn't make sense), therefore instead of returning
-        # ``shared_y`` we will have to cast it to int. This little hack
-        # lets us get around this issue.
-        return shared_x, T.cast(shared_y, 'int32')
-
-    test_set_x, test_set_y = shared_dataset(test_set)
-    valid_set_x, valid_set_y = shared_dataset(valid_set)
-    train_set_x, train_set_y = shared_dataset(train_set)
+    train_set_x, train_set_y = datasets[0]
+    valid_set_x, valid_set_y = datasets[1]
+    test_set_x, test_set_y = datasets[2]
 
     batch_size = 600    # size of the minibatch
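The loading logic deleted above is exactly what the shared load_data helper imported from logistic_sgd consolidates. As a rough sketch only, reconstructed from the removed lines rather than from logistic_sgd.py itself (the real helper may do more, e.g. locate or download the dataset file), it presumably reads along these lines:

import cPickle
import gzip

import numpy
import theano
import theano.tensor as T


def load_data(dataset):
    """Return [(train_x, train_y), (valid_x, valid_y), (test_x, test_y)]
    as Theano shared variables, in the order the new call site expects."""
    # Read the pickled MNIST splits from the gzipped file.
    f = gzip.open(dataset, 'rb')
    train_set, valid_set, test_set = cPickle.load(f)
    f.close()

    def shared_dataset(data_xy, borrow=True):
        # Shared variables let Theano push the whole dataset to GPU memory
        # in one transfer instead of copying every minibatch separately.
        data_x, data_y = data_xy
        shared_x = theano.shared(numpy.asarray(data_x,
                                               dtype=theano.config.floatX),
                                 borrow=borrow)
        shared_y = theano.shared(numpy.asarray(data_y,
                                               dtype=theano.config.floatX),
                                 borrow=borrow)
        # GPU storage must be floatX, but the labels are used as indices,
        # so return an int32-cast view of shared_y.
        return shared_x, T.cast(shared_y, 'int32')

    return [shared_dataset(train_set), shared_dataset(valid_set),
            shared_dataset(test_set)]

With that contract, the call site shrinks to the three unpacking lines added in the hunk above: datasets[0], datasets[1] and datasets[2] are the train, validation and test pairs respectively.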
0 commit comments

Comments
 (0)