Skip to content

Commit 0f36d1b

Browse files
authored
Add GPU configuration for TensorFlow models (Sequential and Functional) (#572)
* Fix grammar in documentation and add GPU configuration class * Updated Tensorflow dependencies in build.gradle and cleaned up LeNetClassicWithGPUMemoryConfig.kt
1 parent a3d633e commit 0f36d1b

File tree

49 files changed

+468
-117
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+468
-117
lines changed

examples/build.gradle

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ dependencies {
1818
testImplementation 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
1919
testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.2'
2020
// to run on GPU (if CUDA is updated and machine with NVIDIA onboard)
21-
/*implementation 'org.tensorflow:libtensorflow:1.15.0'
22-
implementation 'org.tensorflow:libtensorflow_jni_gpu:1.15.0'
23-
api 'com.microsoft.onnxruntime:onnxruntime_gpu:1.12.1'
24-
*/
21+
// implementation 'org.tensorflow:libtensorflow:1.15.0'
22+
// implementation 'org.tensorflow:libtensorflow_jni_gpu:1.15.0'
23+
// api 'com.microsoft.onnxruntime:onnxruntime_gpu:1.12.1'
24+
2525
}
2626

2727
def publishedArtifactsVersion = System.getenv("KOTLIN_DL_RELEASE_VERSION")
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
3+
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
4+
*/
5+
6+
package examples.cnn.mnist.advanced
7+
8+
import org.jetbrains.kotlinx.dl.api.core.GpuConfiguration
9+
import org.jetbrains.kotlinx.dl.api.core.Sequential
10+
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
11+
import org.jetbrains.kotlinx.dl.api.core.initializer.Constant
12+
import org.jetbrains.kotlinx.dl.api.core.initializer.GlorotNormal
13+
import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros
14+
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D
15+
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding
16+
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
17+
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
18+
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.AvgPool2D
19+
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
20+
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
21+
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
22+
import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
23+
import org.jetbrains.kotlinx.dl.api.core.optimizer.ClipGradientByValue
24+
import org.jetbrains.kotlinx.dl.dataset.embedded.NUMBER_OF_CLASSES
25+
import org.jetbrains.kotlinx.dl.dataset.embedded.mnist
26+
import org.jetbrains.kotlinx.dl.impl.summary.logSummary
27+
28+
private const val EPOCHS = 3
29+
private const val TRAINING_BATCH_SIZE = 1000
30+
private const val NUM_CHANNELS = 1L
31+
private const val IMAGE_SIZE = 28L
32+
private const val SEED = 12L
33+
private const val TEST_BATCH_SIZE = 1000
34+
35+
/**
36+
* This example shows how to do image classification from scratch using [lenet5Classic], without leveraging pre-trained weights or a pre-made model.
37+
* We demonstrate the workflow on the Mnist classification dataset.
38+
*
39+
* It could be run only with enabled tensorflow GPU dependencies
40+
*
41+
* It includes:
42+
* - dataset loading from S3
43+
* - model compilation
44+
* - model summary
45+
* - model training
46+
* - model evaluation
47+
*/
48+
fun lenetClassicWithGPUMemoryConfig() {
49+
50+
val layersActivation = Activations.Tanh
51+
val classifierActivation = Activations.Linear
52+
53+
val model = Sequential.of(
54+
Input(
55+
IMAGE_SIZE,
56+
IMAGE_SIZE,
57+
NUM_CHANNELS,
58+
),
59+
Conv2D(
60+
filters = 6,
61+
kernelSize = 5,
62+
strides = 1,
63+
activation = layersActivation,
64+
kernelInitializer = GlorotNormal(SEED),
65+
biasInitializer = Zeros(),
66+
padding = ConvPadding.SAME,
67+
),
68+
AvgPool2D(
69+
poolSize = 2,
70+
strides = 2,
71+
padding = ConvPadding.VALID,
72+
),
73+
Conv2D(
74+
filters = 16,
75+
kernelSize = 5,
76+
strides = 1,
77+
activation = layersActivation,
78+
kernelInitializer = GlorotNormal(SEED),
79+
biasInitializer = Zeros(),
80+
padding = ConvPadding.SAME,
81+
),
82+
AvgPool2D(
83+
poolSize = 2,
84+
strides = 2,
85+
padding = ConvPadding.VALID,
86+
),
87+
Flatten(),
88+
Dense(
89+
outputSize = 120,
90+
activation = layersActivation,
91+
kernelInitializer = GlorotNormal(SEED),
92+
biasInitializer = Constant(0.1f),
93+
),
94+
Dense(
95+
outputSize = 84,
96+
activation = Activations.Tanh,
97+
kernelInitializer = GlorotNormal(SEED),
98+
biasInitializer = Constant(0.1f),
99+
),
100+
Dense(
101+
outputSize = NUMBER_OF_CLASSES,
102+
activation = classifierActivation,
103+
kernelInitializer = GlorotNormal(SEED),
104+
biasInitializer = Constant(0.1f),
105+
),
106+
gpuConfiguration = GpuConfiguration(allowGrowth = true)
107+
)
108+
109+
val (train, test) = mnist()
110+
111+
model.use {
112+
it.compile(
113+
optimizer = Adam(clipGradient = ClipGradientByValue(0.1f)),
114+
loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
115+
metric = Metrics.ACCURACY
116+
)
117+
118+
it.logSummary()
119+
120+
it.fit(dataset = train, epochs = EPOCHS, batchSize = TRAINING_BATCH_SIZE)
121+
122+
val accuracy = it.evaluate(dataset = test, batchSize = TEST_BATCH_SIZE).metrics[Metrics.ACCURACY]
123+
124+
println("Accuracy: $accuracy")
125+
}
126+
}
127+
128+
/** */
129+
fun main(): Unit = lenetClassicWithGPUMemoryConfig()
130+

examples/src/main/kotlin/examples/transferlearning/lenet/Example_1_Load_model_with_weights_and_evaluate_it.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import org.jetbrains.kotlinx.dl.impl.summary.logSummary
1717
import java.io.File
1818

1919
/**
20-
* This examples demonstrates the inference concept:
20+
* This example demonstrates the inference concept:
2121
* - Weights are loaded from .h5 file, configuration is loaded from .json file.
2222
* - Model is evaluated after loading to obtain accuracy value.
2323
* - No additional training.

examples/src/main/kotlin/examples/transferlearning/lenet/Example_2_Load_model_without_weights_and_evaluate_it.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import org.jetbrains.kotlinx.dl.dataset.embedded.fashionMnist
1313
import org.jetbrains.kotlinx.dl.impl.summary.logSummary
1414

1515
/**
16-
* This examples demonstrates the weird inference case:
16+
* This example demonstrates the weird inference case:
1717
* - Weights are not loaded, but initialized via initialized defined in configuration, configuration is loaded from .json file.
1818
* - Model is evaluated after loading to obtain accuracy value.
1919
* - No additional training.

examples/src/main/kotlin/examples/transferlearning/lenet/Example_3_Additional_training.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import org.jetbrains.kotlinx.dl.dataset.embedded.fashionMnist
1414
import org.jetbrains.kotlinx.dl.impl.summary.logSummary
1515

1616
/**
17-
* This examples demonstrates the transfer learning concept:
17+
* This example demonstrates the transfer learning concept:
1818
* - Weights are loaded from .h5 file, configuration is loaded from .json file.
1919
* - All model weights are not frozen, and can be changed during the training.
2020
* - No new layers are added.

examples/src/main/kotlin/examples/transferlearning/lenet/Example_4_Additional_training_and_freezing.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ import org.jetbrains.kotlinx.dl.dataset.embedded.fashionMnist
1717
import org.jetbrains.kotlinx.dl.impl.summary.logSummary
1818

1919
/**
20-
* This examples demonstrates the transfer learning concept:
20+
* This example demonstrates the transfer learning concept:
2121
* - Weights are loaded from .h5 file, configuration is loaded from .json file.
22-
* - Conv2D layer are added to the new Neural Network, its weights are frozen, Dense layers are added too and its weights are not frozen, and can be changed during the training.
22+
* - Conv2D layer is added to the new Neural Network, its weights are frozen, Dense layers are added too and its weights are not frozen, and can be changed during the training.
2323
* - No new layers are added.
2424
*
2525
* NOTE: Model and weights are resources in `examples` module.

examples/src/main/kotlin/examples/transferlearning/lenet/Example_5_Additional_training_and_freezing_and_init.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ import org.jetbrains.kotlinx.dl.dataset.embedded.fashionMnist
1717
import org.jetbrains.kotlinx.dl.impl.summary.logSummary
1818

1919
/**
20-
* This examples demonstrates the transfer learning concept:
20+
* This example demonstrates the transfer learning concept:
2121
* - Weights are loaded from .h5 file for a pre-filtered list of layers (Conv2D only), configuration is loaded from .json file.
22-
* - Conv2D layer are added to the new Neural Network, its weights are frozen, Dense layers are added too and its weights are initialized via defined initializers.
22+
* - Conv2D layer is added to the new Neural Network, its weights are frozen, Dense layers are added too, and its weights are initialized via defined initializers.
2323
* - No new layers are added.
2424
*
2525
* NOTE: Model and weights are resources in `examples` module.

examples/src/main/kotlin/examples/transferlearning/lenet/Example_6_Additional_training_and_new_dense_layers.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ import org.jetbrains.kotlinx.dl.dataset.embedded.fashionMnist
2323
import org.jetbrains.kotlinx.dl.impl.summary.logSummary
2424

2525
/**
26-
* This examples demonstrates the transfer learning concept:
26+
* This example demonstrates the transfer learning concept:
2727
* - Weights are loaded from .h5 file, configuration is loaded from .json file.
28-
* - Conv2D layer are added to the new Neural Network, its weights are frozen.
28+
* - Conv2D layer is added to the new Neural Network, its weights are frozen.
2929
* - Flatten and new Dense layers are added and initialized via defined initializers.
3030
*
3131
* NOTE: Model and weights are resources in `examples` module.

examples/src/main/kotlin/examples/transferlearning/modelhub/densenet/DenseNet121.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.jetbrains.kotlinx.dl.impl.summary.logSummary
2222
import java.io.File
2323

2424
/**
25-
* This examples demonstrates the inference concept on DenseNet121 model:
25+
* This example demonstrates the inference concept on DenseNet121 model:
2626
* - Model configuration, model weights and labels are obtained from [TFModelHub].
2727
* - Weights are loaded from .h5 file, configuration is loaded from .json file.
2828
* - Model predicts on a few images located in resources.

examples/src/main/kotlin/examples/transferlearning/modelhub/densenet/DenseNet169.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.jetbrains.kotlinx.dl.impl.summary.logSummary
2424
import java.io.File
2525

2626
/**
27-
* This examples demonstrates the inference concept on DenseNet169 model:
27+
* This example demonstrates the inference concept on DenseNet169 model:
2828
* - Model configuration, model weights and labels are obtained from [TFModelHub].
2929
* - Weights are loaded from .h5 file, configuration is loaded from .json file.
3030
* - Model predicts on a few images located in resources.

0 commit comments

Comments
 (0)