|
| 1 | +"""A library that contains functions for analyzing the trained entropy autoencoders.""" |
| 2 | + |
| 3 | +import matplotlib |
| 4 | +try: |
| 5 | + import PyQt5 |
| 6 | + matplotlib.use('Qt5Agg') |
| 7 | +except ImportError: |
| 8 | + matplotlib.use('Agg') |
| 9 | +import matplotlib.pyplot as plt |
| 10 | +import numpy |
| 11 | +import os |
| 12 | +import scipy.stats |
| 13 | + |
| 14 | +import eae.graph.constants as csts |
| 15 | +import tools.tools as tls |
| 16 | + |
def activate_latent_variable(sess, isolated_decoder, h_in, w_in, bin_widths, row_activation, col_activation,
                             idx_map_activation, activation_value, map_mean, height_crop, width_crop, path_to_crop):
    """Decodes a latent tensor in which a single latent variable is activated.

    Every latent variable feature map is filled with its mean,
    except for one latent variable that is set to the activation
    value. The resulting latent variable feature maps are quantized,
    fed into the decoder of the entropy autoencoder, and the top-left
    crop of the decoder output is written to disk.

    Parameters
    ----------
    sess : Session
        Session that runs the graph.
    isolated_decoder : IsolatedDecoder
        Decoder of the entropy autoencoder. Its graph
        is built to process one example at a time.
    h_in : int
        Height of the images returned by the isolated decoder.
    w_in : int
        Width of the images returned by the isolated decoder.
    bin_widths : numpy.ndarray
        1D array with data-type `numpy.float32`.
        Quantization bin widths at the end of the training.
    row_activation : int
        Row of the activated latent variable inside the latent
        variable feature map of index `idx_map_activation`.
    col_activation : int
        Column of the activated latent variable inside the latent
        variable feature map of index `idx_map_activation`.
    idx_map_activation : int
        Index of the latent variable feature map that
        contains the activated latent variable.
    activation_value : float
        Value written at the activated position.
    map_mean : numpy.ndarray
        1D array with data-type `numpy.float32`.
        Latent variable feature map means. It is reshaped
        to `csts.NB_MAPS_3` maps.
    height_crop : int
        Height of the crop of the decoder output.
    width_crop : int
        Width of the crop of the decoder output.
    path_to_crop : str
        Path to the saved crop of the decoder
        output. The path ends with ".png".

    """
    # One mean per latent variable feature map, laid out along
    # the last axis so that it can be tiled spatially.
    means_4d = numpy.reshape(map_mean, (1, 1, 1, csts.NB_MAPS_3))

    # The spatial size of the latent variable feature maps is the
    # decoder output size divided by `csts.STRIDE_PROD`.
    repetitions = (1, h_in//csts.STRIDE_PROD, w_in//csts.STRIDE_PROD, 1)

    # `numpy.tile` returns a new array, so the in-place
    # write below does not modify `map_mean`.
    latent_float32 = numpy.tile(means_4d, repetitions)
    latent_float32[0, row_activation, col_activation, idx_map_activation] = activation_value

    quantized_latent_float32 = tls.quantize_per_map(latent_float32, bin_widths)
    feed = {isolated_decoder.node_quantized_y: quantized_latent_float32}
    decoded_float32 = sess.run(isolated_decoder.node_reconstruction, feed_dict=feed)

    # The singleton axes 0 and 3 are dropped to get a 2D image.
    decoded_uint8 = numpy.squeeze(tls.cast_bt601(decoded_float32), axis=(0, 3))
    crop_uint8 = decoded_uint8[:height_crop, :width_crop]
    tls.save_image(path_to_crop, crop_uint8)
| 79 | + |
def fit_maps(y_float32, idx_map_exception, path_to_histogram_locations, path_to_histogram_scales, paths):
    """Fits a Laplace density to the normed histogram of each latent variable feature map.

    For each latent variable feature map, a Laplace location and
    a Laplace scale are estimated, the fitted density is drawn on
    top of the normed histogram of the map, and the figure is saved.
    The locations and scales of all the maps except the map of index
    `idx_map_exception` are then summarized in two histograms.

    Parameters
    ----------
    y_float32 : numpy.ndarray
        4D array with data-type `numpy.float32`.
        Latent variables. `y_float32[i, :, :, j]`
        is the jth latent variable feature map of
        the ith example.
    idx_map_exception : int
        Index of the latent variable feature map
        that is not compressed as the other maps. Its
        location and scale are left out of the two
        summary histograms.
    path_to_histogram_locations : str
        Path to the histogram of the Laplace locations.
        The path ends with ".png".
    path_to_histogram_scales : str
        Path to the histogram of the Laplace scales.
        The path ends with ".png".
    paths : list
        `paths[i]` is the path to the fitted normed histogram
        for the ith latent variable feature map. Each path ends
        with ".png".

    Raises
    ------
    ValueError
        If `len(paths)` is not equal to `y_float32.shape[3]`.

    """
    if len(paths) != y_float32.shape[3]:
        raise ValueError('`len(paths)` is not equal to `y_float32.shape[3]`.')

    kept_locations = []
    kept_scales = []
    for idx_map, path_to_fit in enumerate(paths):
        samples_float32 = y_float32[:, :, :, idx_map]

        # The plotting grid spans the integer interval enclosing
        # all the samples, with 50 points per unit interval.
        lower_edge = numpy.floor(numpy.amin(samples_float32)).item()
        upper_edge = numpy.ceil(numpy.amax(samples_float32)).item()
        nb_points = 50*int(upper_edge - lower_edge) + 1
        grid = numpy.linspace(lower_edge, upper_edge, num=nb_points)

        # The samples in the map are treated as i.i.d draws from
        # an unknown probability density function, modelled by a
        # Laplace density. The location is estimated by the sample
        # mean and the scale by the mean absolute deviation around
        # that location. For background on fitting a density by
        # minimizing the Kullback-Leibler divergence, see:
        # "Estimating distributions and densities". 36-402,
        # advanced data analysis, CMU, 27 January 2011.
        # NOTE(review): the exact maximum likelihood estimate of a
        # Laplace location is the sample median, not the sample
        # mean - confirm that the mean is intended here.
        location = numpy.mean(samples_float32).item()
        scale = numpy.mean(numpy.absolute(samples_float32 - location)).item()
        density = scipy.stats.laplace.pdf(grid, loc=location, scale=scale)
        fitted_line = plt.plot(grid, density, color='red')[0]

        # `edges` holds one more element than `counts`. The
        # last edge is dropped as the bars are left-aligned.
        counts, edges = numpy.histogram(samples_float32, bins=60, density=True)
        bar_width = edges[1] - edges[0]
        plt.bar(edges[:-1], counts, width=bar_width, align='edge', color='blue')

        # The maps are numbered from 1 in the titles.
        plt.title('Latent variable feature map {}'.format(idx_map + 1))
        label = r'$f( . ; {0}, {1})$'.format(str(round(location, 2)), str(round(scale, 2)))

        # `loc=9` is the matplotlib code for 'upper center'.
        plt.legend([fitted_line], [label], prop={'size': 30}, loc=9)
        plt.savefig(path_to_fit)

        # The current figure is cleared so that
        # the next map starts from empty axes.
        plt.clf()

        if idx_map == idx_map_exception:
            continue
        kept_locations.append(location)
        kept_scales.append(scale)

    # `nb_kept` is equal to `y_float32.shape[3] - 1` when
    # `idx_map_exception` is a valid map index.
    nb_kept = len(kept_locations)
    tls.histogram(numpy.array(kept_locations),
                  'Histogram of {} locations'.format(nb_kept),
                  path_to_histogram_locations)
    tls.histogram(numpy.array(kept_scales),
                  'Histogram of {} scales'.format(nb_kept),
                  path_to_histogram_scales)
| 169 | + |
def mask_maps(y_float32, sess, isolated_decoder, bin_widths, idx_unmasked_map, map_mean, height_crop, width_crop, paths):
    """Decodes each example after masking all its latent variable feature maps except one.

    For each example, every latent variable feature map but
    the map of index `idx_unmasked_map` is replaced by its mean.
    The latent variable feature maps are then quantized, fed into
    the decoder of the entropy autoencoder, and the top-left crop
    of the decoder output is written to disk.

    Parameters
    ----------
    y_float32 : numpy.ndarray
        4D array with data-type `numpy.float32`.
        Latent variables. `y_float32[i, :, :, j]`
        is the jth latent variable feature map of
        the ith example.
    sess : Session
        Session that runs the graph.
    isolated_decoder : IsolatedDecoder
        Decoder of the entropy autoencoder. Its graph
        is built to process one example at a time.
    bin_widths : numpy.ndarray
        1D array with data-type `numpy.float32`.
        Quantization bin widths at the end of the training.
    idx_unmasked_map : int
        Index of the unmasked latent variable feature map.
    map_mean : numpy.ndarray
        1D array with data-type `numpy.float32`.
        Latent variable feature map means.
    height_crop : int
        Height of the crop of the decoder output.
    width_crop : int
        Width of the crop of the decoder output.
    paths : list
        The ith string in this list is the path
        to the ith saved crop of the decoder output.
        Each path ends with ".png".

    Raises
    ------
    ValueError
        If `len(paths)` is not equal to `y_float32.shape[0]`.

    """
    if len(paths) != y_float32.shape[0]:
        raise ValueError('`len(paths)` is not equal to `y_float32.shape[0]`.')

    # A single-example tensor is filled with the map means. Only
    # its map of index `idx_unmasked_map` changes from one example
    # to the next, so the tensor is allocated once and that map is
    # overwritten at each iteration.
    shape_means = (1, 1, 1, y_float32.shape[3])
    repetitions = (1, y_float32.shape[1], y_float32.shape[2], 1)
    masked_float32 = numpy.tile(numpy.reshape(map_mean, shape_means), repetitions)

    for idx_example, path_to_crop in enumerate(paths):
        masked_float32[0, :, :, idx_unmasked_map] = y_float32[idx_example, :, :, idx_unmasked_map]
        quantized_float32 = tls.quantize_per_map(masked_float32, bin_widths)
        feed = {isolated_decoder.node_quantized_y: quantized_float32}
        decoded_float32 = sess.run(isolated_decoder.node_reconstruction, feed_dict=feed)

        # The singleton axes 0 and 3 are dropped to get a 2D image.
        decoded_uint8 = numpy.squeeze(tls.cast_bt601(decoded_float32), axis=(0, 3))
        tls.save_image(path_to_crop, decoded_uint8[:height_crop, :width_crop])
| 235 | + |
| 236 | + |
0 commit comments