18
18
from torch .autograd import Variable
19
19
import torch .nn .functional as F
20
20
21
+ import os
21
22
import matplotlib .pyplot as plt
23
+ from matplotlib import animation
24
+ import datetime
22
25
#%matplotlib
23
26
24
27
@@ -84,7 +87,7 @@ def samples(
84
87
return db , pd , pg
85
88
86
89
87
- def plot_distributions (samps , sample_range , ax ):
90
+ def plot_distributions (samps , sample_range , ax , save_img_name ):
88
91
ax .clear ()
89
92
db , pd , pg = samps
90
93
db_x = np .linspace (- sample_range , sample_range , len (db ))
@@ -98,8 +101,9 @@ def plot_distributions(samps, sample_range, ax):
98
101
plt .xlabel ('Data values' )
99
102
plt .ylabel ('Probability density' )
100
103
plt .legend ()
101
- plt .show ()
102
- plt .pause (0.05 )
104
+ plt .savefig (save_img_name )
105
+ #plt.show()
106
+ #plt.pause(0.05)
103
107
104
108
105
109
class DataDistribution (object ):
@@ -150,6 +154,58 @@ def forward(self, x):
150
154
out = F .sigmoid (self .linear4 (h2 ))
151
155
return out
152
156
157
def save_animation(anim_frames, anim_path, sample_range):
    """Render the recorded training snapshots as an mp4 animation.

    Args:
        anim_frames: non-empty list of ``(db, pd, pg)`` triples recorded
            during training — decision boundary, real-data density, and
            generated-data density curves, one triple per saved epoch.
        anim_path: output file path handed to ``Animation.save``; the
            libx264 codec requires ffmpeg to be installed.
        sample_range: half-width of the sampled interval; all three curves
            are plotted over ``[-sample_range, sample_range]``.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    fig.suptitle('1D Generative Adversarial Network', fontsize=15)
    plt.xlabel('Data values')
    plt.ylabel('Probability density')
    ax.set_xlim(-6, 6)
    ax.set_ylim(0, 1.4)
    line_db, = ax.plot([], [], label='decision boundary')
    line_pd, = ax.plot([], [], label='real data')
    line_pg, = ax.plot([], [], label='generated data')
    frame_number = ax.text(
        0.02,
        0.95,
        '',
        horizontalalignment='left',
        verticalalignment='top',
        transform=ax.transAxes,
    )
    ax.legend()

    # The x-coordinates are identical for every frame, so derive them once
    # from the first recorded snapshot instead of inside animate().
    db, pd, _ = anim_frames[0]
    db_x = np.linspace(-sample_range, sample_range, len(db))
    p_x = np.linspace(-sample_range, sample_range, len(pd))

    def init():
        # Blank initial frame, required for FuncAnimation blitting.
        line_db.set_data([], [])
        line_pd.set_data([], [])
        line_pg.set_data([], [])
        frame_number.set_text('')
        return (line_db, line_pd, line_pg, frame_number)

    def animate(i):
        # Fix: show a 1-based counter so the final frame reads "N/N"
        # rather than "N-1/N".
        frame_number.set_text(
            'Frame: {}/{}'.format(i + 1, len(anim_frames))
        )
        db, pd, pg = anim_frames[i]
        line_db.set_data(db_x, db)
        line_pd.set_data(p_x, pd)
        line_pg.set_data(p_x, pg)
        return (line_db, line_pd, line_pg, frame_number)

    anim = animation.FuncAnimation(
        fig,
        animate,
        init_func=init,
        frames=len(anim_frames),
        blit=True,
    )
    anim.save(anim_path, fps=30, extra_args=['-vcodec', 'libx264'])
206
+
207
+
208
+
153
209
154
210
N = 8
155
211
D_in = 1
@@ -158,6 +214,14 @@ def forward(self, x):
158
214
learning_rate = 0.005
159
215
epochs = 10000
160
216
plot_every_epochs = 1000
217
+ current_time = datetime .datetime .now ().strftime ("20%y%m%d_%H%M_%S" )
218
+ output_path = '/tmp/{}' .format (current_time )
219
+ if not os .path .exists (output_path ):
220
+ os .makedirs (output_path )
221
+ anim_path = output_path
222
+
223
+
224
+ anim_frames = []
161
225
162
226
use_cuda = torch .cuda .is_available ()
163
227
@@ -193,7 +257,7 @@ def forward(self, x):
193
257
############################
194
258
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
195
259
###########################
196
- # train with real
260
+ # train with real: maximize log(D(x))
197
261
netD .zero_grad ()
198
262
real_cpu = torch .FloatTensor (data_dist .sample (N ))
199
263
if use_cuda :
@@ -203,19 +267,19 @@ def forward(self, x):
203
267
xv = Variable (x )
204
268
labelv = Variable (label )
205
269
206
- output = netD (xv )
270
+ output = netD (xv ) # D(x)
207
271
errD_real = criterion (output , labelv )
208
272
errD_real .backward ()
209
273
D_x = output .data .mean ()
210
274
211
- # train with fake
275
+ # train with fake: maximize log(1 - D(G(z)))
212
276
z = torch .FloatTensor (gen_dist .sample (N ))[...,None ] # (N_sample, N_channel)
213
277
if use_cuda :
214
278
z = z .cuda ()
215
279
zv = Variable (z )
216
- fake = netG (zv )
280
+ fake = netG (zv ) # G(z)
217
281
labelv = Variable (label .fill_ (fake_label ))
218
- output = netD (fake .detach ())
282
+ output = netD (fake .detach ()) # D(G(z))
219
283
errD_fake = criterion (output , labelv )
220
284
errD_fake .backward ()
221
285
D_G_z1 = output .data .mean ()
@@ -232,7 +296,7 @@ def forward(self, x):
232
296
###########################
233
297
netG .zero_grad ()
234
298
labelv = Variable (label .fill_ (real_label ))
235
- output = netD (fake )
299
+ output = netD (fake ) # D(G(z))
236
300
errG = criterion (output , labelv )
237
301
errG .backward ()
238
302
D_G_z2 = output .data .mean ()
@@ -245,8 +309,11 @@ def forward(self, x):
245
309
246
310
print ('[%d/%d] Loss_D: %.4f Loss_G %.4f D(x): %.4f D(G(z)): %.4f / %.4f' \
247
311
% (epoch , epochs , errD .data [0 ], errG .data [0 ], D_x , D_G_z1 , D_G_z2 ))
248
-
312
+
249
313
if epoch % plot_every_epochs == 0 :
250
314
# Plot distribution
251
315
samps = samples ([netD , netG ], data_dist , gen_dist .range , N )
252
- plot_distributions (samps , gen_dist .range , ax )
316
+ anim_frames .append (samps )
317
+ plot_distributions (samps , gen_dist .range , ax , save_img_name = output_path + '/{:06}' .format (epoch ))
318
+
319
+ # save_animation(anim_frames, anim_path, gen_dist.range)
0 commit comments