Skip to content

Commit 6c42621

Browse files
committed
Various improvements
Added possibility of having multiple input channels for convolution (e.g. color images). Load data on demand for training data sets that are too large to fit in memory. Added convergence test for auto-stopping training. Threaded getPredictions().
1 parent a6e0112 commit 6c42621

File tree

7 files changed

+164
-95
lines changed

7 files changed

+164
-95
lines changed

src/amten/ml/Convolutions.java

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public static Matrix antiPoolDelta(Matrix delta, Matrix prePoolRowIndexes, int n
6262
return result;
6363
}
6464

65-
public static Matrix generatePatchesFromInputLayer(Matrix inputs, int inputWidth, int patchWidth, int patchHeight) {
65+
public static Matrix generatePatchesFromInputLayer(Matrix inputs, int inputWidth, int inputHeight, int patchWidth, int patchHeight) {
6666
// Input data has one row per example.
6767
// Input data has one column per pixel.
6868
// Assumes pixel-numbers are generated row-wise.
@@ -71,20 +71,22 @@ public static Matrix generatePatchesFromInputLayer(Matrix inputs, int inputWidth
7171
// Output data have one row per example/patch
7272
// Output data have one column per patchPixel.
7373

74-
int inputHeight = inputs.numColumns() / inputWidth;
74+
int numChannels = inputs.numColumns() / (inputWidth*inputHeight);
7575
int numPatchesPerExample = (inputWidth-patchWidth+1)*(inputHeight-patchHeight+1);
7676
int numExamples = inputs.numRows();
77-
Matrix output = new Matrix(numExamples*numPatchesPerExample, patchWidth*patchHeight);
77+
Matrix output = new Matrix(numExamples*numPatchesPerExample, numChannels*patchWidth*patchHeight);
7878
for (int example = 0; example < numExamples; example++) {
7979
int patchNum = 0;
8080
for (int inputStartY = 0; inputStartY < inputHeight-patchHeight+1; inputStartY++) {
8181
for (int inputStartX = 0; inputStartX < inputWidth-patchWidth+1; inputStartX++) {
82-
for (int patchPixelY = 0; patchPixelY < patchHeight; patchPixelY++) {
83-
for (int patchPixelX = 0; patchPixelX < patchWidth; patchPixelX++) {
84-
int inputY = inputStartY + patchPixelY;
85-
int inputX = inputStartX + patchPixelX;
86-
double value = inputs.get(example, inputY*inputWidth + inputX);
87-
output.set(example*numPatchesPerExample + patchNum, patchPixelY*patchWidth + patchPixelX, value);
82+
for (int channel = 0; channel < numChannels; channel++) {
83+
for (int patchPixelY = 0; patchPixelY < patchHeight; patchPixelY++) {
84+
for (int patchPixelX = 0; patchPixelX < patchWidth; patchPixelX++) {
85+
int inputY = inputStartY + patchPixelY;
86+
int inputX = inputStartX + patchPixelX;
87+
double value = inputs.get(example, channel*inputHeight*inputWidth + inputY*inputWidth + inputX);
88+
output.set(example*numPatchesPerExample + patchNum, channel*patchHeight*patchWidth + patchPixelY*patchWidth + patchPixelX, value);
89+
}
8890
}
8991
}
9092
patchNum++;

src/amten/ml/DataLoader.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package amten.ml;
2+
3+
import amten.ml.matrix.Matrix;
4+
5+
import java.io.IOException;
6+
7+
/**
8+
* Interface to be implemented by a class that loads training data from disk on-demand.
9+
* Implement when training dataset is too large to fit in memory.
10+
*
11+
* @author Johannes Amtén
12+
*/
13+
public interface DataLoader {
14+
15+
/**
16+
*
17+
* @param xIDs Some type of id numbers for the datapoints to load.
18+
* @return Training data for the specified ids
19+
* @throws IOException
20+
*/
21+
public Matrix loadData(Matrix xIDs) throws IOException;
22+
23+
/**
24+
*
25+
* @return Number of values for each datapoint
26+
*/
27+
public int getDataSize();
28+
29+
}

0 commit comments

Comments
 (0)