@@ -174,36 +174,57 @@ def cd(self, lr = 0.1, persistent=None):
174174 if persistent :
175175 # Note that this works only if persistent is a shared variable
176176 updates [persistent ] = T .cast (nh_sample , dtype = theano .config .floatX )
177+ # pseudo-likelihood is a better proxy for PCD
178+ cost = self .get_pseudo_likelihood_cost (updates )
179+ else :
180+ # reconstruction cross-entropy is a better proxy for CD
181+ cost = self .get_reconstruction_cost (updates , nv_mean )
182+
183+ return cost , updates
177184
178- ####################################################
179- # stochastic approximation to the pseudo-likelihood
180- ####################################################
def get_pseudo_likelihood_cost(self, updates):
    """Stochastic approximation to the pseudo-likelihood.

    Flips a single visible bit (cycling through all bits via the shared
    index ``bit_i_idx``) and uses the free-energy difference between the
    original and bit-flipped configurations as a proxy for
    log P(x_i | x_{\\i}).

    :param updates: dict of Theano updates; the increment of ``bit_i_idx``
                    is added to it (mutated in place).
    :return: symbolic expression for the pseudo-likelihood cost
    """

    # index of bit i in expression p(x_i | x_{\i})
    bit_i_idx = theano.shared(value=0, name='bit_i_idx')

    # binarize the input image by rounding to nearest integer
    xi = T.iround(self.input)

    # calculate free energy for the given bit configuration
    fe_xi = self.free_energy(xi)

    # flip bit x_i of matrix xi and preserve all other bits x_{\i}
    # Equivalent to xi[:, bit_i_idx] = 1 - xi[:, bit_i_idx]
    # NB: slice(start, stop, step) is the python object used for slicing,
    # e.g. to index matrix x as follows: x[start:stop:step]
    xi_flip = T.setsubtensor(xi, 1 - xi[:, bit_i_idx],
                             idx_list=(slice(None, None, None), bit_i_idx))

    # calculate free energy with bit flipped
    fe_xi_flip = self.free_energy(xi_flip)

    # equivalent to e^(-FE(x_i)) / (e^(-FE(x_i)) + e^(-FE(x_{\i})))
    # NOTE(review): this expression is per-example (a vector over the
    # minibatch); later tutorial revisions wrap it in T.mean(...) to get a
    # scalar cost — confirm before taking gradients.
    cost = self.n_visible * T.log(T.nnet.sigmoid(fe_xi_flip - fe_xi))

    # increment bit_i_idx % number as part of updates
    updates[bit_i_idx] = (bit_i_idx + 1) % self.n_visible

    return cost
203214
def get_reconstruction_cost(self, updates, nv_mean):
    """Approximation to the reconstruction error.

    Computes the mean (over the minibatch) binary cross-entropy between the
    input and the mean-field visible activations ``nv_mean`` obtained after
    one Gibbs step.

    :param updates: dict of Theano updates (not modified here; kept so the
                    signature parallels ``get_pseudo_likelihood_cost``)
    :param nv_mean: symbolic mean of the visible units after sampling
    :return: symbolic cross-entropy expression
    """

    cross_entropy = T.mean(
        T.sum(self.input * T.log(nv_mean) +
              (1 - self.input) * T.log(1 - nv_mean),
              axis=1))

    return cross_entropy
223+
224+
225+
226+ def test_rbm (learning_rate = 0.1 , training_epochs = 15 ,
227+ dataset = 'mnist.pkl.gz' ):
207228 """
208229 Demonstrate ***
209230
@@ -242,7 +263,7 @@ def test_rbm( learning_rate=0.1, training_epochs = 15, \
242263 n_hidden = 500 ,numpy_rng = rng , theano_rng = theano_rng )
243264
244265 # get the cost and the gradient corresponding to one step of CD
245- updates , cost = rbm .cd (lr = learning_rate , persistent = persistent_chain )
266+ cost , updates = rbm .cd (lr = learning_rate , persistent = persistent_chain )
246267
247268
248269 #################################
@@ -296,11 +317,9 @@ def test_rbm( learning_rate=0.1, training_epochs = 15, \
296317 # find out the number of test samples
297318 number_of_test_samples = test_set_x .value .shape [0 ]
298319
299- # pick two initial starting points randomly
300- sample = rng .randint (number_of_test_samples - 20 )
301-
302- # Initialize the persistent chain with some sample from the test
303- persistent_vis_chain = theano .shared (test_set_x .value [sample :sample + 20 ])
320+ # pick random test examples, with which to initialize the persistent chain
321+ test_idx = rng .randint (number_of_test_samples - 20 )
322+ persistent_vis_chain = theano .shared (test_set_x .value [test_idx :test_idx + 20 ])
304323
305324 # define one step of Gibbs sampling (mf = mean-field)
306325 [hid_mf , hid_sample , vis_mf , vis_sample ] = rbm .gibbs_vhv (persistent_vis_chain )
@@ -338,7 +357,4 @@ def test_rbm( learning_rate=0.1, training_epochs = 15, \
338357 image .save ('sample_%i_step_%i.png' % (idx ,idx * jdx ))
339358
if __name__ == '__main__':
    # Run the RBM demo with its default hyper-parameters
    # (learning_rate=0.1, training_epochs=15, dataset='mnist.pkl.gz').
    test_rbm()
0 commit comments