
Commit 47aafdc

Merge pull request lisa-lab#1551 from lamblin/conv_nonlinearity_cost

Conv nonlinearity cost

2 parents 74fd21b + cf9c464

File tree: 7 files changed, +251 -113 lines

.travis.yml

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ before_install:
   - export PATH=/home/travis/miniconda/bin:$PATH
   - conda update --yes conda
 install:
-  - if [[ $TRAVIS_PYTHON_VERSION == '2.6' ]]; then conda create --yes -q -n pyenv python=2.6 mkl pyzmq cython=0.2 pillow numpy=1.6 numpydoc scipy=0.11 pytables=3.0 numexpr=2.2.2 nose=1.1 pyyaml sphinx pyflakes argparse pip matplotlib scikit-learn h5py; fi
+  - if [[ $TRAVIS_PYTHON_VERSION == '2.6' ]]; then conda create --yes -q -n pyenv python=2.6 mkl pyzmq cython=0.21.1 pillow numpy=1.9.1 numpydoc scipy=0.14.0 pytables=3.1.1 numexpr=2.3.1 nose=1.3.4 pyyaml sphinx pyflakes argparse pip matplotlib scikit-learn h5py; fi
   - if [[ $TRAVIS_PYTHON_VERSION == '3.4' ]]; then conda create --yes -q -n pyenv python=3.4 mkl pyzmq cython=0.21.1 pillow numpy=1.9.1 numpydoc scipy=0.14.0 pytables=3.1.1 numexpr=2.3.1 nose=1.3.4 pyyaml sphinx pyflakes pip matplotlib scikit-learn h5py; fi
   - source activate pyenv
   - pip install -q git+git://git.assembla.com/jobman.git

pylearn2/datasets/tests/test_preprocessing.py

Lines changed: 3 additions & 2 deletions

@@ -239,7 +239,7 @@ def test_zca(self):

         # Check if preprocessed data matrix is white
         assert_allclose(np.cov(preprocessed_X.transpose(),
-                        bias=1), identity, rtol=1e-4)
+                        bias=1), identity, rtol=1e-4, atol=1e-4)

         # Check if we obtain correct solution
         zca_transformed_X = np.array(
@@ -290,7 +290,8 @@ def test(store_inverse):
                                  fit_preprocessor=True)

             preprocessed_X = dataset.get_design_matrix()
-            assert_allclose(self.X, preprocessor.inverse(preprocessed_X))
+            assert_allclose(self.X, preprocessor.inverse(preprocessed_X),
+                            atol=5e-5, rtol=1e-5)

         test(store_inverse=True)
         test(store_inverse=False)
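
A note on why the added `atol` values are needed: `assert_allclose` accepts a
result when |actual - desired| <= atol + rtol * |desired| (NumPy's rule, which
the Theano helper used in these tests is assumed to follow), so a pure `rtol`
check allows zero error wherever the expected value is exactly zero, such as
the off-diagonal entries of the identity matrix above. A minimal sketch of
that behavior, assuming only NumPy:

    import numpy as np
    from numpy.testing import assert_allclose

    desired = np.eye(3)          # expected whitened covariance
    actual = desired + 1e-5      # small absolute error everywhere

    # Fails: the allowed error at the zero entries is atol + rtol * 0 = 0.
    try:
        assert_allclose(actual, desired, rtol=1e-4)
    except AssertionError:
        pass

    # Passes: atol gives the zero entries a nonzero tolerance.
    assert_allclose(actual, desired, rtol=1e-4, atol=1e-4)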

pylearn2/models/mlp.py

Lines changed: 47 additions & 27 deletions

@@ -2680,6 +2680,27 @@ def get_monitoring_channels_from_state(self, state, target,

         return rval

+    def cost(self, Y, Y_hat, batch_axis):
+        """
+        The cost of outputting Y_hat when the true output is Y.
+
+        Parameters
+        ----------
+        Y : theano.gof.Variable
+            Targets
+        Y_hat : theano.gof.Variable
+            Output of `fprop`
+        batch_axis : integer
+            axis representing batch dimension
+
+        Returns
+        -------
+        cost : theano.gof.Variable
+            0-D tensor describing the cost
+        """
+        raise NotImplementedError(
+            str(type(self)) + " does not implement cost function.")
+

 class IdentityConvNonlinearity(ConvNonlinearity):

@@ -2708,6 +2729,15 @@ def get_monitoring_channels_from_state(self,

         return rval

+    @wraps(ConvNonlinearity.cost, append=True)
+    def cost(self, Y, Y_hat, batch_axis):
+        """
+        Notes
+        -----
+        Mean squared error across examples in a batch
+        """
+        return T.sum(T.mean(T.sqr(Y - Y_hat), axis=batch_axis))
+

 class RectifierConvNonlinearity(ConvNonlinearity):

@@ -2820,6 +2850,19 @@ def get_monitoring_channels_from_state(self, state, target,

         return rval

+    @wraps(ConvNonlinearity.cost, append=True)
+    def cost(self, Y, Y_hat, batch_axis):
+        """
+        Notes
+        -----
+        Cost: mean across units, mean across batch of the KL divergence
+        KL(P || Q), where P is defined by Y and Q is defined by Y_hat.
+        KL(P || Q) = p log p - p log q + (1-p) log (1-p) - (1-p) log (1-q)
+        """
+        ave_total = kl(Y=Y, Y_hat=Y_hat, batch_axis=batch_axis)
+        ave = ave_total.mean()
+        return ave
+

 class TanhConvNonlinearity(ConvNonlinearity):

@@ -3255,39 +3298,16 @@ def fprop(self, state_below):

         return p

+    @wraps(Layer.cost, append=True)
     def cost(self, Y, Y_hat):
         """
-        Cost for convnets is hardcoded to be the cost for sigmoids.
-        TODO: move the cost into the non-linearity class.
-
-        Parameters
-        ----------
-        Y : theano.gof.Variable
-            Output of `fprop`
-        Y_hat : theano.gof.Variable
-            Targets
-
-        Returns
-        -------
-        cost : theano.gof.Variable
-            0-D tensor describing the cost
-
         Notes
         -----
-        Cost mean across units, mean across batch of KL divergence
-        KL(P || Q) where P is defined by Y and Q is defined by Y_hat
-        KL(P || Q) = p log p - p log q + (1-p) log (1-p) - (1-p) log (1-q)
+        The cost method calls `self.nonlin.cost`
         """
-        assert self.nonlin.non_lin_name == "sigmoid", ("ConvElemwise "
-                                                       "supports "
-                                                       "cost function "
-                                                       "for only "
-                                                       "sigmoid layer "
-                                                       "for now.")
+
         batch_axis = self.output_space.get_batch_axis()
-        ave_total = kl(Y=Y, Y_hat=Y_hat, batch_axis=batch_axis)
-        ave = ave_total.mean()
-        return ave
+        return self.nonlin.cost(Y=Y, Y_hat=Y_hat, batch_axis=batch_axis)


 class ConvRectifiedLinear(ConvElemwise):
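
The net effect of this change: the layer's cost is now dispatched to its
nonlinearity object instead of being hardcoded to the sigmoid KL cost. A
minimal NumPy sketch of the same delegation pattern, where `SimpleConvLayer`
and the two nonlinearity classes are hypothetical stand-ins rather than
pylearn2 API:

    import numpy as np


    class IdentityNonlinearity(object):
        def cost(self, Y, Y_hat, batch_axis):
            # mean squared error across the batch, summed over units
            return np.mean((Y - Y_hat) ** 2, axis=batch_axis).sum()


    class RectifierNonlinearity(object):
        def cost(self, Y, Y_hat, batch_axis):
            # mirrors the base-class behavior added above
            raise NotImplementedError(
                str(type(self)) + " does not implement cost function.")


    class SimpleConvLayer(object):
        def __init__(self, nonlin, batch_axis=0):
            self.nonlin = nonlin
            self.batch_axis = batch_axis

        def cost(self, Y, Y_hat):
            # defer to whichever nonlinearity the layer was built with
            return self.nonlin.cost(Y, Y_hat, self.batch_axis)


    Y = np.zeros((4, 3))
    Y_hat = np.ones((4, 3))
    print(SimpleConvLayer(IdentityNonlinearity()).cost(Y, Y_hat))  # 3.0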
New file

Lines changed: 141 additions & 0 deletions

@@ -0,0 +1,141 @@
+"""
+Note: Cost functions are not implemented for RectifierConvNonlinearity,
+TanhConvNonlinearity, RectifiedLinear, and Tanh. Here we verify that the
+implemented cost functions for convolutional layers give the correct output
+by comparing to standard MLPs.
+"""
+
+import numpy as np
+from numpy.testing import assert_raises
+
+import theano
+from theano import config
+from theano.tests.unittest_tools import assert_allclose
+
+from pylearn2.models.mlp import MLP
+from pylearn2.models.mlp import Sigmoid, Tanh, Linear, RectifiedLinear
+from pylearn2.models.mlp import ConvElemwise
+from pylearn2.space import Conv2DSpace
+from pylearn2.models.mlp import SigmoidConvNonlinearity
+from pylearn2.models.mlp import TanhConvNonlinearity
+from pylearn2.models.mlp import IdentityConvNonlinearity
+from pylearn2.models.mlp import RectifierConvNonlinearity
+
+
+def check_case(conv_nonlinearity, mlp_nonlinearity, cost_implemented=True):
+    """Check that a ConvNonlinearity and an MLP nonlinearity are consistent.
+
+    This is done by building an MLP with a ConvElemwise layer using the
+    supplied non-linearity, an MLP with a dense layer, and checking that
+    the outputs (and costs, if applicable) are consistent.
+
+    Parameters
+    ----------
+    conv_nonlinearity: instance of `ConvNonlinearity`
+        The non-linearity to provide to a `ConvElemwise` layer.
+
+    mlp_nonlinearity: subclass of `mlp.Linear`
+        The fully-connected MLP layer (including non-linearity).
+
+    cost_implemented: bool
+        If `True`, check that both costs give consistent results.
+        If `False`, check that both costs raise `NotImplementedError`.
+    """
+
+    # Create fake data
+    np.random.seed(12345)
+
+    r = 31
+    s = 21
+    shape = [r, s]
+    nvis = r * s
+    output_channels = 13
+    batch_size = 103
+
+    x = np.random.rand(batch_size, r, s, 1)
+    y = np.random.randint(2, size=[batch_size, output_channels, 1, 1])
+
+    x = x.astype(config.floatX)
+    y = y.astype(config.floatX)
+
+    x_mlp = x.flatten().reshape(batch_size, nvis)
+    y_mlp = y.flatten().reshape(batch_size, output_channels)
+
+    # Initialize convnet with random weights.
+    conv_model = MLP(
+        input_space=Conv2DSpace(shape=shape,
+                                axes=['b', 0, 1, 'c'],
+                                num_channels=1),
+        layers=[ConvElemwise(layer_name='conv',
+                             nonlinearity=conv_nonlinearity,
+                             output_channels=output_channels,
+                             kernel_shape=shape,
+                             pool_shape=[1, 1],
+                             pool_stride=shape,
+                             irange=1.0)],
+        batch_size=batch_size
+    )
+
+    X = conv_model.get_input_space().make_theano_batch()
+    Y = conv_model.get_target_space().make_theano_batch()
+    Y_hat = conv_model.fprop(X)
+    g = theano.function([X], Y_hat)
+
+    # Construct an equivalent MLP which gives the same output
+    # after flattening both.
+    mlp_model = MLP(
+        layers=[mlp_nonlinearity(dim=output_channels,
+                                 layer_name='mlp',
+                                 irange=1.0)],
+        batch_size=batch_size,
+        nvis=nvis
+    )
+
+    W, b = conv_model.get_param_values()
+
+    W_mlp = np.zeros(shape=(output_channels, nvis), dtype=config.floatX)
+    for k in range(output_channels):
+        W_mlp[k] = W[k, 0].flatten()[::-1]
+    W_mlp = W_mlp.T
+    b_mlp = b.flatten()
+
+    mlp_model.set_param_values([W_mlp, b_mlp])
+
+    X1 = mlp_model.get_input_space().make_theano_batch()
+    Y1 = mlp_model.get_target_space().make_theano_batch()
+    Y1_hat = mlp_model.fprop(X1)
+    f = theano.function([X1], Y1_hat)
+
+    # Check that the two models give the same output
+    assert_allclose(f(x_mlp).flatten(), g(x).flatten(), rtol=1e-5, atol=5e-5)
+
+    if cost_implemented:
+        # Check that the two models have the same costs
+        mlp_cost = theano.function([X1, Y1], mlp_model.cost(Y1, Y1_hat))
+        conv_cost = theano.function([X, Y], conv_model.cost(Y, Y_hat))
+        assert_allclose(conv_cost(x, y), mlp_cost(x_mlp, y_mlp))
+    else:
+        # Check that both costs are not implemented
+        assert_raises(NotImplementedError, conv_model.cost, Y, Y_hat)
+        assert_raises(NotImplementedError, mlp_model.cost, Y1, Y1_hat)
+
+
+def test_all_costs():
+    """Check all instances of ConvNonlinearity.
+
+    Either they should be consistent with the corresponding subclass
+    of `Linear`, or their `cost` method should not be implemented.
+    """
+
+    cases = [[SigmoidConvNonlinearity(), Sigmoid, True],
+             [IdentityConvNonlinearity(), Linear, True],
+             [TanhConvNonlinearity(), Tanh, False],
+             [RectifierConvNonlinearity(), RectifiedLinear, False]]
+
+    for conv_nonlinearity, mlp_nonlinearity, cost_implemented in cases:
+        check_case(conv_nonlinearity, mlp_nonlinearity, cost_implemented)
+
+
+if __name__ == "__main__":
+    test_all_costs()
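
The kernel reversal in `W_mlp[k] = W[k, 0].flatten()[::-1]` is the subtle
step: a "valid" convolution whose kernel covers the entire input produces a
single output equal to a dot product against the flipped kernel, so the dense
layer must use the reversed weights to match the conv layer. A minimal check
of that identity, assuming NumPy and SciPy (scipy.signal is not used by the
test itself):

    import numpy as np
    from scipy.signal import convolve2d

    rng = np.random.RandomState(0)
    x = rng.rand(5, 4)   # input "image"
    w = rng.rand(5, 4)   # kernel of the same size -> 1x1 output

    conv_out = convolve2d(x, w, mode='valid')[0, 0]   # true convolution
    dense_out = x.flatten().dot(w.flatten()[::-1])    # reversed weights

    assert np.allclose(conv_out, dense_out)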

pylearn2/optimization/linesearch.py

Lines changed: 20 additions & 38 deletions

@@ -9,7 +9,7 @@
 import theano
 import theano.tensor as TT
 from theano.ifelse import ifelse
-from theano.sandbox.scan import scan
+from theano import scan
 import numpy

 one = TT.constant(numpy.asarray(1, dtype=theano.config.floatX))
@@ -109,27 +109,23 @@ def armijo(alpha0, alpha1, phi_a0, phi_a1):
         return [alpha1, alpha2, phi_a1, phi_a2], \
                theano.scan_module.until(end_condition)

-    states = []
-    states += [TT.unbroadcast(TT.shape_padleft(alpha0), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(alpha1), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(phi_a0), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(phi_a1), 0)]
+    states = [alpha0, alpha1, phi_a0, phi_a1]
     # print 'armijo'
     rvals, _ = scan(
         armijo,
-        states=states,
+        outputs_info=states,
         n_steps=n_iters,
         name='armijo',
         mode=theano.Mode(linker='cvm'),
         profile=profile)

-    sol_scan = rvals[1][0]
+    sol_scan = rvals[1][-1]
     a_opt = ifelse(csol1, one,
                    ifelse(csol2, alpha1,
                           sol_scan))
     score = ifelse(csol1, phi_a0,
                    ifelse(csol2, phi_a1,
-                          rvals[2][0]))
+                          rvals[2][-1]))
     return a_opt, score

@@ -279,31 +275,26 @@ def while_search(alpha0, alpha1, phi_a0, phi_a1, derphi_a0, i_t,
                          cond1,
                          cond2,
                          cond3)))
-    states = []
-    states += [TT.unbroadcast(TT.shape_padleft(alpha0), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(alpha1), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(phi_a0), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(phi_a1), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(derphi_a0), 0)]
+    states = [alpha0, alpha1, phi_a0, phi_a1, derphi_a0]
     # i_t
-    states += [TT.unbroadcast(TT.shape_padleft(zero), 0)]
+    states.append(zero)
     # alpha_star
-    states += [TT.unbroadcast(TT.shape_padleft(zero), 0)]
+    states.append(zero)
     # phi_star
-    states += [TT.unbroadcast(TT.shape_padleft(zero), 0)]
+    states.append(zero)
     # derphi_star
-    states += [TT.unbroadcast(TT.shape_padleft(zero), 0)]
+    states.append(zero)
     # print 'while_search'
     outs, updates = scan(while_search,
-                         states=states,
+                         outputs_info=states,
                          n_steps=maxiter,
                          name='while_search',
                          mode=theano.Mode(linker='cvm_nogc'),
                          profile=profile)
     # print 'done_while_search'
-    out3 = outs[-3][0]
-    out2 = outs[-2][0]
-    out1 = outs[-1][0]
+    out3 = outs[-3][-1]
+    out2 = outs[-2][-1]
+    out1 = outs[-1][-1]
     alpha_star, phi_star, derphi_star = \
         ifelse(TT.eq(alpha1, zero),
                (nan, phi0, nan),
@@ -629,28 +620,19 @@ def while_zoom(phi_rec, a_rec, a_lo, a_hi, phi_hi,
     derphi_lo.name = 'derphi_lo'
     vderphi_aj = ifelse(cond1, nan, TT.switch(cond2, derphi_aj, nan),
                         name='vderphi_aj')
-    states = []
-    states += [TT.unbroadcast(TT.shape_padleft(phi_rec), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(a_rec), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(a_lo), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(a_hi), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(phi_hi), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(phi_lo), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(derphi_lo), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(zero), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(zero), 0)]
-    states += [TT.unbroadcast(TT.shape_padleft(zero), 0)]
+    states = [phi_rec, a_rec, a_lo, a_hi, phi_hi, phi_lo, derphi_lo,
+              zero, zero, zero]
+
     # print 'while_zoom'
     outs, updates = scan(while_zoom,
-                         states=states,
+                         outputs_info=states,
                          n_steps=maxiter,
                          name='while_zoom',
                          mode=theano.Mode(linker='cvm_nogc'),
                          profile=profile)
     # print 'done_while'
-    a_star = ifelse(onlyif, a_j, outs[7][0], name='astar')
-    val_star = ifelse(onlyif, phi_aj, outs[8][0], name='valstar')
-    valprime = ifelse(onlyif, vderphi_aj, outs[9][0], name='valprime')
+    a_star = ifelse(onlyif, a_j, outs[7][-1], name='astar')
+    val_star = ifelse(onlyif, phi_aj, outs[8][-1], name='valstar')
+    valprime = ifelse(onlyif, vderphi_aj, outs[9][-1], name='valprime')

     ## WARNING !! I ignore updates given by scan which I should not do !!!
     return a_star, val_star, valprime
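
The pattern behind all three hunks: the sandbox scan took pre-padded `states`
and kept only the final value of each state (hence the `shape_padleft` calls
and the `[0]` indexing), while the stock theano.scan takes plain initial
values via `outputs_info` and returns the full history of each state, so the
final value lives at index `[-1]`. A minimal sketch of the new-style call,
under those assumptions:

    import theano
    import theano.tensor as TT

    x0 = TT.scalar('x0')

    def halve(x):
        # keep halving until x drops below 0.1
        return x / 2, theano.scan_module.until(x < 0.1)

    values, updates = theano.scan(halve, outputs_info=[x0], n_steps=10)
    final = values[-1]   # last state, not values[0]

    f = theano.function([x0], final, updates=updates)
    print(f(1.0))        # 0.03125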
