Skip to content

Commit 60393a8

Browse files
committed
can use GPU
1 parent 46f971a commit 60393a8

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

gan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@
6363

6464
discriminator.trainable = False
6565

66-
gan_input = keras.Input(shape=(latent_dim,))
67-
gan_output = discriminator(generator(gan_input))
66+
gan_input = keras.Input(shape=(latent_dim,))
67+
gan_output = discriminator(generator(gan_input))
6868
gan = keras.models.Model(gan_input, gan_output)
6969

7070
gan_optimizer = keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8)
@@ -92,7 +92,7 @@
9292
random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
9393

9494
generated_images = generator.predict(random_latent_vectors)
95-
95+
9696
stop = start + batch_size
9797
real_images = x_train[start: stop]
9898
combined_images = np.concatenate([generated_images, real_images])

gpu.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
print('checking what CPU / GPU you have available for Keras & Tensorflow')
2+
3+
from tensorflow.python.client import device_lib
4+
print(device_lib.list_local_devices())
5+

requirements.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,9 @@ numpy==1.13.0
77
opencv-python
88
pylint
99
tensorflow
10+
# get that GPU speed!
11+
tensorflow-gpu
1012
# html5lib newer version needed to get tensorboard to work
11-
html5lib==1.0.1
13+
html5lib==1.0.1
14+
# required for PIL -- image manipulation, at least in gan.py
15+
pillow

0 commit comments

Comments
 (0)