Skip to content

Commit 10895b2

Browse files
committed
save plt figure
1 parent aaa2631 commit 10895b2

File tree

2 files changed

+79
-11
lines changed

2 files changed

+79
-11
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.sw*

gan_train.py

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
from torch.autograd import Variable
1919
import torch.nn.functional as F
2020

21+
import os
2122
import matplotlib.pyplot as plt
23+
from matplotlib import animation
24+
import datetime
2225
#%matplotlib
2326

2427

@@ -84,7 +87,7 @@ def samples(
8487
return db, pd, pg
8588

8689

87-
def plot_distributions(samps, sample_range, ax):
90+
def plot_distributions(samps, sample_range, ax, save_img_name):
8891
ax.clear()
8992
db, pd, pg = samps
9093
db_x = np.linspace(-sample_range, sample_range, len(db))
@@ -98,8 +101,9 @@ def plot_distributions(samps, sample_range, ax):
98101
plt.xlabel('Data values')
99102
plt.ylabel('Probability density')
100103
plt.legend()
101-
plt.show()
102-
plt.pause(0.05)
104+
plt.savefig(save_img_name)
105+
#plt.show()
106+
#plt.pause(0.05)
103107

104108

105109
class DataDistribution(object):
@@ -150,6 +154,58 @@ def forward(self, x):
150154
out = F.sigmoid(self.linear4(h2))
151155
return out
152156

157+
def save_animation(anim_frames, anim_path, sample_range):
158+
f, ax = plt.subplots(figsize=(6, 4))
159+
f.suptitle('1D Generative Adversarial Network', fontsize=15)
160+
plt.xlabel('Data values')
161+
plt.ylabel('Probability density')
162+
ax.set_xlim(-6, 6)
163+
ax.set_ylim(0, 1.4)
164+
line_db, = ax.plot([], [], label='decision boundary')
165+
line_pd, = ax.plot([], [], label='real data')
166+
line_pg, = ax.plot([], [], label='generated data')
167+
frame_number = ax.text(
168+
0.02,
169+
0.95,
170+
'',
171+
horizontalalignment='left',
172+
verticalalignment='top',
173+
transform=ax.transAxes
174+
)
175+
ax.legend()
176+
177+
db, pd, _ = anim_frames[0]
178+
db_x = np.linspace(-sample_range, sample_range, len(db))
179+
p_x = np.linspace(-sample_range, sample_range, len(pd))
180+
181+
def init():
182+
line_db.set_data([], [])
183+
line_pd.set_data([], [])
184+
line_pg.set_data([], [])
185+
frame_number.set_text('')
186+
return (line_db, line_pd, line_pg, frame_number)
187+
188+
def animate(i):
189+
frame_number.set_text(
190+
'Frame: {}/{}'.format(i, len(anim_frames))
191+
)
192+
db, pd, pg = anim_frames[i]
193+
line_db.set_data(db_x, db)
194+
line_pd.set_data(p_x, pd)
195+
line_pg.set_data(p_x, pg)
196+
return (line_db, line_pd, line_pg, frame_number)
197+
198+
anim = animation.FuncAnimation(
199+
f,
200+
animate,
201+
init_func=init,
202+
frames=len(anim_frames),
203+
blit=True
204+
)
205+
anim.save(anim_path, fps=30, extra_args=['-vcodec', 'libx264'])
206+
207+
208+
153209

154210
N = 8
155211
D_in = 1
@@ -158,6 +214,14 @@ def forward(self, x):
158214
learning_rate = 0.005
159215
epochs = 10000
160216
plot_every_epochs = 1000
217+
current_time = datetime.datetime.now().strftime("20%y%m%d_%H%M_%S")
218+
output_path = '/tmp/{}'.format(current_time)
219+
if not os.path.exists(output_path):
220+
os.makedirs(output_path)
221+
anim_path = output_path
222+
223+
224+
anim_frames = []
161225

162226
use_cuda = torch.cuda.is_available()
163227

@@ -193,7 +257,7 @@ def forward(self, x):
193257
############################
194258
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
195259
###########################
196-
# train with real
260+
# train with real: maximize log(D(x))
197261
netD.zero_grad()
198262
real_cpu = torch.FloatTensor(data_dist.sample(N))
199263
if use_cuda:
@@ -203,19 +267,19 @@ def forward(self, x):
203267
xv = Variable(x)
204268
labelv = Variable(label)
205269

206-
output = netD(xv)
270+
output = netD(xv) # D(x)
207271
errD_real = criterion(output, labelv)
208272
errD_real.backward()
209273
D_x = output.data.mean()
210274

211-
# train with fake
275+
# train with fake: maximize log(1 - D(G(z)))
212276
z = torch.FloatTensor(gen_dist.sample(N))[...,None] # (N_sample, N_channel)
213277
if use_cuda:
214278
z = z.cuda()
215279
zv = Variable(z)
216-
fake = netG(zv)
280+
fake = netG(zv) # G(z)
217281
labelv = Variable(label.fill_(fake_label))
218-
output = netD(fake.detach())
282+
output = netD(fake.detach()) # D(G(z))
219283
errD_fake = criterion(output, labelv)
220284
errD_fake.backward()
221285
D_G_z1 = output.data.mean()
@@ -232,7 +296,7 @@ def forward(self, x):
232296
###########################
233297
netG.zero_grad()
234298
labelv = Variable(label.fill_(real_label))
235-
output = netD(fake)
299+
output = netD(fake) # D(G(z))
236300
errG = criterion(output, labelv)
237301
errG.backward()
238302
D_G_z2 = output.data.mean()
@@ -245,8 +309,11 @@ def forward(self, x):
245309

246310
print('[%d/%d] Loss_D: %.4f Loss_G %.4f D(x): %.4f D(G(z)): %.4f / %.4f' \
247311
% (epoch, epochs, errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2))
248-
312+
249313
if epoch % plot_every_epochs == 0:
250314
# Plot distribution
251315
samps = samples([netD, netG], data_dist, gen_dist.range, N)
252-
plot_distributions(samps, gen_dist.range, ax)
316+
anim_frames.append(samps)
317+
plot_distributions(samps, gen_dist.range, ax, save_img_name = output_path+'/{:06}'.format(epoch))
318+
319+
# save_animation(anim_frames, anim_path, gen_dist.range)

0 commit comments

Comments
 (0)