Skip to content

Commit 3404f32

Browse files
committed
Auto-choose batchsize
1 parent c8b415a commit 3404f32

File tree

5 files changed

+10
-5
lines changed

5 files changed

+10
-5
lines changed

src/amten/ml/NNParams.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,11 @@ public NNParams() {
8787
* Number of examples to use in each mini-batch.
8888
* Batch-size 100 is a good choice for fully connected networks.
8989
* Batch-size 1 is a good choice for convolutional networks.
90+
* If set to 0, the batch-size is chosen automatically: 100 for a fully connected network, 1 for a convolutional network.
9091
*
91-
* Default is 100.
92+
* Default is 0.
9293
*/
93-
public int batchSize = 100;
94+
public int batchSize = 0;
9495

9596

9697
/**

src/amten/ml/NeuralNetwork.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ public void train(Matrix x, Matrix y) throws Exception {
113113
myParams.inputWidth = inputSize;
114114
}
115115

116+
if (myParams.batchSize == 0) {
117+
// Auto-choose batch-size.
118+
// 100 for fully connected network and 1 for convolutional network.
119+
myParams.batchSize = myLayerParams[1].isConvolutional() ? 1 : 100;
120+
}
121+
116122
initThetas();
117123
// If threads == 0, use the same number of threads as cores.
118124
myParams.numThreads = myParams.numThreads > 0 ? myParams.numThreads : Runtime.getRuntime().availableProcessors();

src/amten/ml/examples/NNClassificationExample.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ public static void runKaggleDigitsClassification(boolean useConvolution) throws
4646
params.numClasses = 10; // 10 digits to classify
4747
params.hiddenLayerParams = useConvolution ? new NNParams.NNLayerParams[]{ new NNParams.NNLayerParams(20, 5, 5, 2, 2) , new NNParams.NNLayerParams(100, 5, 5, 2, 2) } :
4848
new NNParams.NNLayerParams[] { new NNParams.NNLayerParams(100) };
49-
params.batchSize = useConvolution ? 1 : 100;
5049
params.maxIterations = useConvolution ? 10 : 200;
5150
params.learningRate = useConvolution ? 1E-2 : 0;
5251

src/amten/ml/test/NeuralNetworkTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ public static void runKaggleDigitsClassification() throws Exception {
4141
params.numClasses = 10; // 10 digits to classify
4242
params.hiddenLayerParams = new NNParams.NNLayerParams[] { new NNParams.NNLayerParams(20, 5, 5, 2, 2) , new NNParams.NNLayerParams(100, 5, 5, 2, 2) };
4343
params.learningRate = 1E-2;
44-
params.batchSize = 1;
4544
params.maxIterations = 10;
4645

4746
long startTime = System.currentTimeMillis();

src/weka/classifiers/functions/NeuralNetwork.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ public void setBatchSize(int batchSize) {
250250
myParams.batchSize = batchSize;
251251
}
252252
public String batchSizeTipText() {
253-
return "Number of training examples in each mini-batch (=1 recommended for convolutional networks) .";
253+
return "Number of training examples in each mini-batch (0=Auto-choose) .";
254254
}
255255

256256
public int getThreads() {

0 commit comments

Comments
 (0)