@@ -84,9 +84,11 @@ def tile_raster_images(X, img_shape, tile_shape,tile_spacing = (0,0),
8484 if X [i ] is None :
8585 # if channel is None, fill it with zeros of the correct
8686 # dtype
87+ dt = out_array .dtype
88+ if output_pixel_vals :
89+ dt = 'uint8'
8790 out_array [:,:,i ] = numpy .zeros (out_shape ,
88- dtype = 'uint8' if output_pixel_vals else out_array .dtype
89- )+ channel_defaults [i ]
91+ dtype = dt )+ channel_defaults [i ]
9092 else :
9193 # use a recurrent call to compute the channel and store it
9294 # in the output
@@ -99,7 +101,10 @@ def tile_raster_images(X, img_shape, tile_shape,tile_spacing = (0,0),
99101 Hs , Ws = tile_spacing
100102
101103 # generate a matrix to store the output
102- out_array = numpy .zeros (out_shape , dtype = 'uint8' if output_pixel_vals else X .dtype )
104+ dt = X .dtype
105+ if output_pixel_vals :
106+ dt = 'uint8'
107+ out_array = numpy .zeros (out_shape , dtype = dt )
103108
104109
105110 for tile_row in xrange (tile_shape [0 ]):
@@ -114,11 +119,14 @@ def tile_raster_images(X, img_shape, tile_shape,tile_spacing = (0,0),
114119 this_img = X [tile_row * tile_shape [1 ] + tile_col ].reshape (img_shape )
115120 # add the slice to the corresponding position in the
116121 # output array
122+ c = 1
123+ if output_pixel_vals :
124+ c = 255
117125 out_array [
118126 tile_row * (H + Hs ):tile_row * (H + Hs )+ H ,
119127 tile_col * (W + Ws ):tile_col * (W + Ws )+ W
120128 ] \
121- = this_img * ( 255 if output_pixel_vals else 1 )
129+ = this_img * c
122130 return out_array
123131
124132
0 commit comments