
Commit cfa2018

Author: Mathew Donald Rogers

Fixed cost functions for IdentityConvNonlinearity and SigmoidConvNonlinearity classes.

1 parent 7c6249f commit cfa2018

File tree

2 files changed: +82, -33 lines


pylearn2/models/mlp.py

Lines changed: 65 additions & 22 deletions
@@ -2673,6 +2673,13 @@ def get_monitoring_channels_from_state(self, state, target,
         rval = self._get_monitoring_channels_for_activations(state)
 
         return rval
+
+    @wraps(Layer.cost)
+    def cost(self, Y, Y_hat):
+        raise NotImplementedError(
+            str(type(self)) + " does not implement cost function.")
+
+
 
 
 class IdentityConvNonlinearity(ConvNonlinearity):
@@ -2701,6 +2708,30 @@ get_monitoring_channels_from_state(self,
             rval["misclass"] = T.cast(incorrect, config.floatX).mean()
 
         return rval
+
+    @wraps(Linear.cost)
+    def cost(self, Y, Y_hat, batch_axis):
+        """
+        Parameters
+        ----------
+        Y : theano.gof.Variable
+            Output of `fprop`
+        Y_hat : theano.gof.Variable
+            Targets
+        batch_axis : integer
+            axis representing batch dimension
+
+        Returns
+        -------
+        cost : theano.gof.Variable
+            0-D tensor describing the cost
+
+        Notes
+        -----
+        Mean squared error across batch
+        """
+        return T.sum(T.mean(T.sqr(Y - Y_hat), axis=batch_axis))
+
 
 
 class RectifierConvNonlinearity(ConvNonlinearity):
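
For intuition, the new identity cost averages the squared error over the batch axis and then sums over the remaining unit axes. A minimal NumPy sketch of the same reduction (the function name and shapes here are illustrative, not from the commit):

import numpy as np

def identity_cost(Y, Y_hat, batch_axis=0):
    # Mirrors T.sum(T.mean(T.sqr(Y - Y_hat), axis=batch_axis)):
    # mean of the squared error over the batch, summed over units.
    return np.sum(np.mean(np.square(Y - Y_hat), axis=batch_axis))

Y = np.random.rand(103, 13)      # (batch, output_channels)
Y_hat = np.random.rand(103, 13)
print identity_cost(Y, Y_hat)    # a single scalar, like the 0-D tensor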
@@ -2813,6 +2844,33 @@ def get_monitoring_channels_from_state(self, state, target,
             rval['per_output_f1_min'] = f1.min()
 
         return rval
+
+    @wraps(Linear.cost)
+    def cost(self, Y, Y_hat, batch_axis):
+        """
+        Parameters
+        ----------
+        Y : theano.gof.Variable
+            Output of `fprop`
+        Y_hat : theano.gof.Variable
+            Targets
+        batch_axis : integer
+            axis representing batch dimension
+
+        Returns
+        -------
+        cost : theano.gof.Variable
+            0-D tensor describing the cost
+
+        Notes
+        -----
+        Cost mean across units, mean across batch of KL divergence
+        KL(P || Q) where P is defined by Y and Q is defined by Y_hat
+        KL(P || Q) = p log p - p log q + (1-p) log (1-p) - (1-p) log (1-q)
+        """
+        ave_total = kl(Y=Y, Y_hat=Y_hat, batch_axis=batch_axis)
+        ave = ave_total.mean()
+        return ave
 
 
 class TanhConvNonlinearity(ConvNonlinearity):
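
As a worked check of the docstring's KL formula, here is a small NumPy sketch; the function name and the eps clipping are my own additions, not a reproduction of pylearn2's kl helper:

import numpy as np

def bernoulli_kl(p, q, eps=1e-7):
    # Elementwise KL(P || Q) for Bernoulli parameters, following
    # p log p - p log q + (1-p) log (1-p) - (1-p) log (1-q);
    # eps keeps the logs finite when p or q is exactly 0 or 1.
    p = np.clip(p, eps, 1 - eps)
    q = np.clip(q, eps, 1 - eps)
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

p = np.array([0.1, 0.5, 0.9])
q = np.array([0.1, 0.4, 0.5])
print bernoulli_kl(p, q)   # zero where p == q, positive elsewhere

The commit's cost then takes the mean of this quantity across units and batch, as the Notes section above states.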
@@ -3248,41 +3306,26 @@ def fprop(self, state_below):
             p = self.output_normalization(p)
 
         return p
-
+
+    @wraps(Layer.cost)
     def cost(self, Y, Y_hat):
         """
-        Cost for convnets is hardcoded to be the cost for sigmoids.
-        TODO: move the cost into the non-linearity class.
-
         Parameters
         ----------
         Y : theano.gof.Variable
-            Output of `fprop`
+            Output of `fprop`
         Y_hat : theano.gof.Variable
             Targets
 
         Returns
         -------
         cost : theano.gof.Variable
             0-D tensor describing the cost
-
-        Notes
-        -----
-        Cost mean across units, mean across batch of KL divergence
-        KL(P || Q) where P is defined by Y and Q is defined by Y_hat
-        KL(P || Q) = p log p - p log q + (1-p) log (1-p) - (1-p) log (1-q)
-        """
-        assert self.nonlin.non_lin_name == "sigmoid", ("ConvElemwise "
-                                                       "supports "
-                                                       "cost function "
-                                                       "for only "
-                                                       "sigmoid layer "
-                                                       "for now.")
+        """
+
         batch_axis = self.output_space.get_batch_axis()
-        ave_total = kl(Y=Y, Y_hat=Y_hat, batch_axis=batch_axis)
-        ave = ave_total.mean()
-        return ave
-
+        return self.nonlin.cost(Y=Y, Y_hat=Y_hat, batch_axis=batch_axis)
+
 
 
 class ConvRectifiedLinear(ConvElemwise):
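The net effect of that last hunk: ConvElemwise no longer hardcodes the sigmoid KL cost behind an assert; it delegates to whichever nonlinearity object it holds. A condensed, NumPy-only sketch of the dispatch (class names echo the commit, but the bodies are simplified stand-ins rather than pylearn2's real implementations):

import numpy as np

class ConvNonlinearity(object):
    def cost(self, Y, Y_hat, batch_axis):
        # Base behaviour added by this commit: refuse unless overridden.
        raise NotImplementedError(
            str(type(self)) + " does not implement cost function.")

class IdentityConvNonlinearity(ConvNonlinearity):
    def cost(self, Y, Y_hat, batch_axis):
        # Mean squared error across the batch, summed over units.
        return np.sum(np.mean(np.square(Y - Y_hat), axis=batch_axis))

class FakeConvElemwise(object):
    # Hypothetical stand-in for the real layer; it only shows the dispatch.
    def __init__(self, nonlin, batch_axis=0):
        self.nonlin = nonlin
        self.batch_axis = batch_axis

    def cost(self, Y, Y_hat):
        # No sigmoid-only assert any more: forward to the nonlinearity.
        return self.nonlin.cost(Y=Y, Y_hat=Y_hat, batch_axis=self.batch_axis)

layer = FakeConvElemwise(IdentityConvNonlinearity())
Y = np.random.rand(4, 3)
print layer.cost(Y, np.zeros((4, 3)))   # dispatches; the base class would raise

This keeps each cost next to the activation it describes, so later adding a cost for Tanh or Rectifier only touches the corresponding subclass.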

pylearn2/models/tests/test_costs.py

Lines changed: 17 additions & 11 deletions
@@ -1,7 +1,8 @@
 
 """
-Notes: Cost function is not implemented for IdentityConvNonlinearity, RectifierConvNonlinearity, TanhConvNonlinearity. It is bugged for SigmoidConvNonlinearity, but we are
-not triggering that bug here. The cost function is not implemented for standard mlp RectifiedLinear or Tanh.
+Notes: Cost function is not implemented for IdentityConvNonlinearity, RectifierConvNonlinearity,
+TanhConvNonlinearity. It is bugged for SigmoidConvNonlinearity, but we are not triggering the
+bug here. The cost function is also not implemented for standard mlp RectifiedLinear or Tanh.
 """
 
 
@@ -17,25 +18,25 @@
 from pylearn2.space import Conv2DSpace
 from pylearn2.models.mlp import SigmoidConvNonlinearity, TanhConvNonlinearity, IdentityConvNonlinearity, RectifierConvNonlinearity
 
-
-
-
 #def test_costs():
 
 # Create fake data
 np.random.seed(12345)
 
 
-r = 13
-s = 11
+r = 31
+s = 21
 shape = [r, s]
 nvis = r*s
-output_channels = 17
-batch_size = 1
+output_channels = 13
+batch_size = 103
 
 x = np.random.rand(batch_size, r, s, 1)
 y = np.random.randint(2, size=[batch_size, output_channels, 1, 1])
 
+x = x.astype('float32')
+y = y.astype('float32')
+
 x_mlp = x.flatten().reshape(batch_size, nvis)
 y_mlp = y.flatten().reshape(batch_size, output_channels)
 

@@ -63,11 +64,15 @@
 )
 
 W, b = conv_model.get_param_values()
+W = W.astype('float32')
+b = b.astype('float32')
 W_mlp = np.zeros(shape=(output_channels, nvis))
 for k in range(output_channels):
     W_mlp[k] = W[k, 0].flatten()[::-1]
 W_mlp = W_mlp.T
 b_mlp = b.flatten()
+W_mlp = W_mlp.astype('float32')
+b_mlp = b_mlp.astype('float32')
 mlp_model.set_param_values([W_mlp, b_mlp])
 
 X1 = mlp_model.get_input_space().make_theano_batch()
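
The [::-1] above undoes the kernel flip that convolution performs: a 'valid' convolution whose kernel is as large as its input (which the test's 1x1 conv output shape implies) equals a dot product against the reversed, flattened kernel. A small check of that identity, using scipy rather than Theano and illustrative shapes:

import numpy as np
from scipy.signal import convolve2d

img = np.random.rand(5, 4)
kern = np.random.rand(5, 4)

# 'valid' convolution with a full-size kernel yields a single number...
conv_out = convolve2d(img, kern, mode='valid')[0, 0]

# ...equal to a dot product with the flipped, flattened kernel, which is
# exactly how the test builds each row of W_mlp from W[k, 0].
mlp_out = img.flatten().dot(kern.flatten()[::-1])

assert abs(conv_out - mlp_out) < 1e-10
print "kernel-flip identity ok"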
@@ -77,13 +82,14 @@
 
 
 # Check that the two models give the same throughput
-assert np.linalg.norm(f(x_mlp).flatten() - g(x).flatten()) < 10**-10
-print "Fprop ok"
+assert np.linalg.norm(f(x_mlp).flatten() - g(x).flatten()) < 10**-3
+print "f-prop ok"
 
 # Cost functions:
 mlp_cost = theano.function([X1, Y1], mlp_model.cost(Y1, Y1_hat))
 print "mlp_cost = "+str(mlp_cost(x_mlp, y_mlp))
 
+batch_axis = T.scalar()
 conv_cost = theano.function([X, Y], conv_model.cost(Y, Y_hat))
 print "conv_cost = "+str(conv_cost(x, y))
 