Skip to content

Commit aaa2631

Browse files
committed
support py3, fix bug
1 parent ab4f8a5 commit aaa2631

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

gan_train.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch.nn.functional as F
2020

2121
import matplotlib.pyplot as plt
22-
%matplotlib
22+
#%matplotlib
2323

2424

2525
# custom weights initialization called on netG and netD
@@ -34,7 +34,7 @@ def update_learning_rate(optimizer, epoch, init_lr, decay_rate, lr_decay_epochs)
3434
lr = init_lr * (decay_rate**(epoch // lr_decay_epochs))
3535

3636
if epoch % lr_decay_epochs == 0:
37-
print 'LR set to {}'.format(lr)
37+
print('LR set to {}'.format(lr))
3838

3939
for param_group in optimizer.param_groups:
4040
param_group['lr'] = lr
@@ -113,13 +113,14 @@ def sample(self, N):
113113
return np.reshape(samples, (-1, 1))
114114

115115

116-
class GeneratorDistribution(object):
    """Input-noise distribution for the GAN generator.

    Samples are a stratified grid over [-range, range]: evenly spaced
    points with a small positive random jitter (< 0.01) added to each,
    so consecutive draws cover the support instead of clustering.
    """

    def __init__(self, range):
        # NOTE: parameter name `range` shadows the builtin; kept as-is
        # for backward compatibility with keyword callers.
        self.range = range

    def sample(self, N):
        """Return N jittered, evenly spaced samples in [-range, range+0.01).

        Parameters
        ----------
        N : int
            Number of samples to draw.

        Returns
        -------
        np.ndarray of shape (N,)
            Monotonically increasing samples (jitter is far smaller than
            the grid spacing for any reasonable N).
        """
        samples = np.linspace(-self.range, self.range, N) + \
            np.random.random(N) * 0.01
        return samples
123124

124125

125126
class Generator(torch.nn.Module):
@@ -208,7 +209,7 @@ def forward(self, x):
208209
D_x = output.data.mean()
209210

210211
# train with fake
211-
z = torch.FloatTensor(gen_dist.sample(N))
212+
z = torch.FloatTensor(gen_dist.sample(N))[...,None] # (N_sample, N_channel)
212213
if use_cuda:
213214
z = z.cuda()
214215
zv = Variable(z)
@@ -242,8 +243,8 @@ def forward(self, x):
242243
decay_rate=0.95,
243244
lr_decay_epochs=150)
244245

245-
print '[%d/%d] Loss_D: %.4f Loss_G %.4f D(x): %.4f D(G(z)): %.4f / %.4f' \
246-
% (epoch, epochs, errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2)
246+
print('[%d/%d] Loss_D: %.4f Loss_G %.4f D(x): %.4f D(G(z)): %.4f / %.4f' \
247+
% (epoch, epochs, errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2))
247248

248249
if epoch % plot_every_epochs == 0:
249250
# Plot distribution

0 commit comments

Comments
 (0)