Skip to content

Commit 512e951

Browse files
committed
install_env
1 parent 10895b2 commit 512e951

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

gan_train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,9 @@ def animate(i):
207207

208208

209209

210-
N = 8
211-
D_in = 1
212-
H = 4
210+
N = 8 # batch size
211+
D_in = 1 # input size of D
212+
H = 4 # numbr of hidden neurons
213213
D_out = 1
214214
learning_rate = 0.005
215215
epochs = 10000
@@ -292,12 +292,12 @@ def animate(i):
292292
lr_decay_epochs=150)
293293

294294
############################
295-
# (2) Update G network: maximize log(D(G(z)))
295+
# (2) Update G network: maximize log(D(G(z))): guide D make wrong prediction: G(z) --> real_label(1)
296296
###########################
297297
netG.zero_grad()
298298
labelv = Variable(label.fill_(real_label))
299299
output = netD(fake) # D(G(z))
300-
errG = criterion(output, labelv)
300+
errG = criterion(output, labelv)
301301
errG.backward()
302302
D_G_z2 = output.data.mean()
303303
optimizerG.step()

install_env.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pip install -r requirements.txt
2+
pip install datetime
3+
pip3 install torch torchvision # or use `conda install pytorch torchvision -c pytorch`

requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
matplotlib==1.5.3
2+
numpy==1.11.3
3+
scipy==0.17.0
4+
seaborn==0.7.1
5+
tensorflow==1.2.0

0 commit comments

Comments
 (0)