Skip to content

Commit 7bc7f36

Browse files
committed
backport to python2.4
1 parent e03eab8 commit 7bc7f36

2 files changed

Lines changed: 15 additions & 5 deletions

File tree

code/rbm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def __init__(self, input=None, n_visible=784, n_hidden=500, \
8282

8383

8484
# initialize input layer for standalone RBM or layer0 of DBN
85-
self.input = input if input else T.dmatrix('input')
85+
self.input = input
86+
if not input:
87+
self.input = T.dmatrix('input')
8688

8789
self.W = W
8890
self.hbias = hbias

code/utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,11 @@ def tile_raster_images(X, img_shape, tile_shape,tile_spacing = (0,0),
8484
if X[i] is None:
8585
# if channel is None, fill it with zeros of the correct
8686
# dtype
87+
dt = out_array.dtype
88+
if output_pixel_vals:
89+
dt = 'uint8'
8790
out_array[:,:,i] = numpy.zeros(out_shape,
88-
dtype='uint8' if output_pixel_vals else out_array.dtype
89-
)+channel_defaults[i]
91+
dtype=dt)+channel_defaults[i]
9092
else:
9193
# use a recurrent call to compute the channel and store it
9294
# in the output
@@ -99,7 +101,10 @@ def tile_raster_images(X, img_shape, tile_shape,tile_spacing = (0,0),
99101
Hs, Ws = tile_spacing
100102

101103
# generate a matrix to store the output
102-
out_array = numpy.zeros(out_shape, dtype='uint8' if output_pixel_vals else X.dtype)
104+
dt = X.dtype
105+
if output_pixel_vals:
106+
dt = 'uint8'
107+
out_array = numpy.zeros(out_shape, dtype=dt)
103108

104109

105110
for tile_row in xrange(tile_shape[0]):
@@ -114,11 +119,14 @@ def tile_raster_images(X, img_shape, tile_shape,tile_spacing = (0,0),
114119
this_img = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape)
115120
# add the slice to the corresponding position in the
116121
# output array
122+
c = 1
123+
if output_pixel_vals:
124+
c = 255
117125
out_array[
118126
tile_row * (H+Hs):tile_row*(H+Hs)+H,
119127
tile_col * (W+Ws):tile_col*(W+Ws)+W
120128
] \
121-
= this_img * (255 if output_pixel_vals else 1)
129+
= this_img * c
122130
return out_array
123131

124132

0 commit comments

Comments
 (0)