
Commit ca36627

Cleaned up interface a bit. Some javadoc
1 parent 8697456 commit ca36627

3 files changed: +394 −236 lines

src/amten/ml/NeuralNetwork.java

Lines changed: 163 additions & 45 deletions
@@ -14,27 +14,102 @@
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 /**
- * Created by Johannes Amtén on 2014-02-24.
+ * Neural network implementation with dropout and rectified linear units.
+ * Can perform regression or classification.
+ * Training is done by multithreaded mini-batch gradient descent, using a native matrix library.
+ *
+ * @author Johannes Amtén
  *
  */
 
 public class NeuralNetwork implements Serializable{
 
     private Matrix[] myThetas = null;
-    private boolean mySoftmax = true;
+    private boolean mySoftmax = false;
     private double myInputLayerDropoutRate = 0.0;
     private double myHiddenLayersDropoutRate = 0.0;
 
+    private double[] myAverages = null;
+    private double[] myStdDevs = null;
+    private int[] myNumCategories = null;
+    private int myNumClasses = 0;
+
     private transient final ReentrantReadWriteLock myThetasLock = new ReentrantReadWriteLock();
     private transient ExecutorService myExecutorService;
 
+    /**
+     * Create an empty Neural Network.
+     * Use train() to generate weights.
+     */
     public NeuralNetwork() {
     }
 
-    public void train(Matrix x, Matrix y, int[] hiddenUnits, double lambda, double alpha, int batchSize, int iterations,
-                      int threads, double inputLayerDropoutRate, double hiddenLayersDropoutRate, boolean softmax, boolean debug) throws Exception {
+    /**
+     * Train the neural network.
+     *
+     * @param x Training data, input.
+     *          One row for each training example.
+     *          One column for each attribute.
+     * @param numCategories Number of categories for each nominal attribute in x.
+     *                      This array should have the same length as the number of columns in x.
+     *                      For each nominal attribute, the value should be equal to the number of categories of that attribute.
+     *                      For each numeric attribute, the value should be 1.
+     *                      If numCategories is null, all attributes in x will be interpreted as numeric.
+     * @param y Training data, correct output.
+     * @param numClasses Number of classes, if classification.
+     *                   1 if regression.
+     * @param hiddenUnits Number of units in each hidden layer.
+     * @param weightPenalty L1 weight penalty.
+     *                      Even when using dropout, it may be a good idea to have a small weight penalty, to keep weights down and avoid overflow.
+     * @param learningRate Initial learning rate.
+     *                     If 0, different learning rates will be tried to automatically find a good initial rate.
+     * @param batchSize Number of examples to use in each mini-batch.
+     * @param iterations Number of iterations (epochs) of training to perform.
+     *                   Training may be halted earlier by the user, if the debug flag is set.
+     * @param threads Number of concurrent calculations.
+     *                If 0, threads will automatically be set to the number of CPU cores found.
+     * @param inputLayerDropoutRate Dropout rate of the input layer.
+     *                              Typically set somewhat lower than the hidden-layer dropout rate, e.g. 0.2.
+     * @param hiddenLayersDropoutRate Dropout rate of the hidden layers.
+     *                                Typically set somewhat higher than the input-layer dropout rate, e.g. 0.5.
+     * @param debug If true, training progress will be output to the console.
+     *              Also, the user will be able to halt training by pressing enter in the console.
+     * @param normalizeNumericData If true, the data in all numeric columns will be normalized, by subtracting the average and dividing by the standard deviation.
+     * @throws Exception
+     */
+    public void train(Matrix x, int[] numCategories, Matrix y, int numClasses, int[] hiddenUnits, double weightPenalty, double learningRate, int batchSize, int iterations,
+                      int threads, double inputLayerDropoutRate, double hiddenLayersDropoutRate, boolean debug, boolean normalizeNumericData) throws Exception {
+        myNumCategories = numCategories;
+        if (myNumCategories == null) {
+            myNumCategories = new int[x.numColumns()];
+            Arrays.fill(myNumCategories, 1);
+        }
+        myNumClasses = numClasses;
+        mySoftmax = myNumClasses > 1;
+
+        if (normalizeNumericData) {
+            x = x.copy();
+            myAverages = new double[x.numColumns()];
+            myStdDevs = new double[x.numColumns()];
+            for (int col = 0; col < x.numColumns(); col++) {
+                if (myNumCategories[col] <= 1) {
+                    // Normalize numeric column.
+                    myAverages[col] = MatrixUtils.getAverage(x, col);
+                    myStdDevs[col] = MatrixUtils.getStandardDeviation(x, col);
+                    MatrixUtils.normalizeData(x, col, myAverages[col], myStdDevs[col]);
+                }
+            }
+        } else {
+            myAverages = null;
+            myStdDevs = null;
+        }
+
+        // Expand nominal values to groups of booleans.
+        x = MatrixUtils.expandNominalAttributes(x, myNumCategories);
+        y = MatrixUtils.expandNominalAttributes(y, new int[] {myNumClasses});
+
         initThetas(x.numColumns(), hiddenUnits, y.numColumns());
-        mySoftmax = softmax;
         myInputLayerDropoutRate = inputLayerDropoutRate;
         myHiddenLayersDropoutRate = hiddenLayersDropoutRate;
         // If threads == 0, use the same number of threads as cores.
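
For reference, a minimal usage sketch of the new train(...) signature documented above (not part of this commit). The toy XOR data, the import path of the Matrix class and the chosen hyperparameter values are illustrative assumptions:

import amten.ml.NeuralNetwork;
import amten.ml.matrix.Matrix; // assumed location of the Matrix class

public class TrainSketch {
    public static void main(String[] args) throws Exception {
        // Toy data: 4 examples, 2 numeric input columns, 2-class label (XOR).
        Matrix x = new Matrix(new double[][]{{0, 0}, {0, 1}, {1, 0}, {1, 1}});
        Matrix y = new Matrix(new double[][]{{0}, {1}, {1}, {0}});

        NeuralNetwork nn = new NeuralNetwork();
        nn.train(x, null, y, 2,      // numCategories == null => all columns numeric; 2 classes => softmax
                 new int[]{10},      // one hidden layer with 10 units
                 1.0E-8,             // small L1 weight penalty
                 0.0,                // 0 => auto-find the initial learning rate
                 4, 100, 0,          // batch size, iterations, 0 => one thread per CPU core
                 0.2, 0.5,           // input-layer and hidden-layer dropout rates
                 false, true);       // no debug output; normalize numeric columns
    }
}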
@@ -47,35 +122,35 @@ public void train(Matrix x, Matrix y, int[] hiddenUnits, double lambda, double a
         List<Matrix> batchesX = new ArrayList<>();
         List<Matrix> batchesY = new ArrayList<>();
         MatrixUtils.split(x, y, batchSize, batchesX, batchesY);
-        if (alpha == 0.0) {
-            // Auto-find initial alpha.
-            alpha = findInitialAlpha(x, y, lambda, debug);
+        if (learningRate == 0.0) {
+            // Auto-find initial learningRate.
+            learningRate = findInitialLearningRate(x, y, weightPenalty, debug);
         }
 
-        double cost = getCostThreaded(batchesX, batchesY, lambda);
+        double cost = getCostThreaded(batchesX, batchesY, weightPenalty);
         LinkedList<Double> oldCosts = new LinkedList<>();
         if (debug) {
             System.out.println("\n\n*** Training network. Press <enter> to halt. ***\n");
-            System.out.println("Iteration: 0" + ", Cost: " + String.format("%.3E", cost) + ", Alpha: " + String.format("%.1E", alpha));
+            System.out.println("Iteration: 0" + ", Cost: " + String.format("%.3E", cost) + ", Learning rate: " + String.format("%.1E", learningRate));
         }
         for (int i = 0; i < iterations && !halted; i++) {
             // Regenerate the batches each iteration, to get random samples each time.
             MatrixUtils.split(x, y, batchSize, batchesX, batchesY);
-            trainOneIterationThreaded(batchesX, batchesY, alpha, lambda);
-            cost = getCostThreaded(batchesX, batchesY, lambda);
+            trainOneIterationThreaded(batchesX, batchesY, learningRate, weightPenalty);
+            cost = getCostThreaded(batchesX, batchesY, weightPenalty);
 
             if (oldCosts.size() == 5) {
                 // Lower learning rate if cost haven't decreased for 5 iterations.
                 double oldCost = oldCosts.remove();
                 double minCost = Math.min(cost, Collections.min(oldCosts));
                 if (minCost >= oldCost) {
-                    alpha = alpha*0.1;
+                    learningRate = learningRate*0.1;
                     oldCosts.clear();
                 }
             }
 
             if (debug) {
-                System.out.println("Iteration: " + (i + 1) + ", Cost: " + String.format("%.3E", cost) + ", Alpha: " + String.format("%.1E", alpha));
+                System.out.println("Iteration: " + (i + 1) + ", Cost: " + String.format("%.3E", cost) + ", Learning rate: " + String.format("%.1E", learningRate));
             }
             oldCosts.add(cost);
 
@@ -92,17 +167,60 @@ public void train(Matrix x, Matrix y, int[] hiddenUnits, double lambda, double a
         myExecutorService.shutdown();
     }
 
+    /**
+     * Get predictions for a number of input examples.
+     *
+     * @param x Matrix with one row for each input example and one column for each input attribute.
+     * @return Matrix with one row for each example.
+     *         If regression, only one column containing the predicted value.
+     *         If classification, one column for each class, containing the predicted probability of that class.
+     */
     public Matrix getPredictions(Matrix x) {
-        Matrix[] a = feedForward(x, null);
-        return a[a.length-1];
+        if (myAverages != null) {
+            x = x.copy();
+            for (int col = 0; col < x.numColumns(); col++) {
+                if (myNumCategories[col] <= 1) {
+                    // Normalize numeric column.
+                    MatrixUtils.normalizeData(x, col, myAverages[col], myStdDevs[col]);
+                }
+            }
+        }
+        // Expand nominal values to groups of booleans.
+        x = MatrixUtils.expandNominalAttributes(x, myNumCategories);
+
+        Matrix[] activations = feedForward(x, null);
+        return activations[activations.length-1];
     }
 
-    public double[] getPredictions(double[] x) {
-        Matrix xMatrix = new Matrix(new double[][]{x});
-        Matrix[] a = feedForward(xMatrix, null);
-        return a[a.length-1].getData();
+    /**
+     * Get classification predictions for a number of input examples.
+     *
+     * @param x Matrix with one row for each input example and one column for each input attribute.
+     * @return Array with one element for each example, containing the predicted class.
+     */
+    public int[] getPredictedClasses(Matrix x) {
+        Matrix y = getPredictions(x);
+        int[] predictedClasses = new int[x.numRows()];
+        for (int row = 0; row < y.numRows(); row++) {
+            int prediction = 0;
+            double predMaxValue = Double.MIN_VALUE;
+            for (int col = 0; col < y.numColumns(); col++) {
+                if (y.get(row, col) > predMaxValue) {
+                    predMaxValue = y.get(row, col);
+                    prediction = col;
+                }
+            }
+            predictedClasses[row] = prediction;
+        }
+        return predictedClasses;
     }
 
+//    public double[] getPredictions(double[] x) {
+//        Matrix xMatrix = new Matrix(new double[][]{x});
+//        Matrix[] a = feedForward(xMatrix, null);
+//        return a[a.length-1].getData();
+//    }
+
     private void initThetas(int inputs, int[] hidden, int outputs) {
         ArrayList<Integer> numNodes = new ArrayList<>();
         numNodes.add(inputs);
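
A minimal prediction sketch using the two public methods added above (not part of this commit). It assumes nn is a NeuralNetwork already trained for classification, newX has the same column layout as the training input, and the import path of Matrix is a guess:

import amten.ml.NeuralNetwork;
import amten.ml.matrix.Matrix; // assumed location of the Matrix class

public class PredictSketch {
    static void printPredictions(NeuralNetwork nn, Matrix newX) {
        Matrix probabilities = nn.getPredictions(newX); // one row per example, one column per class
        int[] classes = nn.getPredictedClasses(newX);   // index of the most probable class per example
        for (int row = 0; row < classes.length; row++) {
            System.out.println("Example " + row + ": class " + classes[row]
                    + " (p=" + probabilities.get(row, classes[row]) + ")");
        }
    }
}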
@@ -176,7 +294,7 @@ private int numberOfNodes() {
         return nodes;
     }
 
-    private double getCost(Matrix x, Matrix y, double lambda, int batchSize) {
+    private double getCost(Matrix x, Matrix y, double weightPenalty, int batchSize) {
         double c = 0.0;
 
         Matrix[] a = feedForward(x, null);
@@ -204,21 +322,21 @@ private double getCost(Matrix x, Matrix y, double lambda, int batchSize) {
             c = c/(2*batchSize);
         }
 
-        if (lambda > 0) {
+        if (weightPenalty > 0) {
             // Regularization
             double regSum = 0.0;
             for (Matrix theta:myThetas) {
                 for (MatrixElement me: theta.getColumns(1, -1)) {
                     regSum += Math.abs(me.value());
                 }
             }
-            c += regSum*lambda/numberOfNodes();
+            c += regSum*weightPenalty/numberOfNodes();
         }
 
         return c;
     }
 
-    private double getCostThreaded(List<Matrix> batchesX, List<Matrix> batchesY, final double lambda) throws Exception {
+    private double getCostThreaded(List<Matrix> batchesX, List<Matrix> batchesY, final double weightPenalty) throws Exception {
         final int batchSize = batchesX.get(0).numRows();
         // Queue up cost calculation in thread pool
         List<Future<Double>> costJobs = new ArrayList<>();
@@ -228,7 +346,7 @@ private double getCostThreaded(List<Matrix> batchesX, List<Matrix> batchesY, fin
             Callable<Double> costCalculator = new Callable<Double>() {
                 public Double call() throws Exception {
                     myThetasLock.readLock().lock();
-                    double cost = getCost(bx, by, lambda, batchSize);
+                    double cost = getCost(bx, by, weightPenalty, batchSize);
                     myThetasLock.readLock().unlock();
                     return cost;
                 }
@@ -245,7 +363,7 @@ public Double call() throws Exception {
         return cost;
     }
 
-    private Matrix[] getGradients(Matrix x, Matrix y, double lambda, Matrix[] dropoutMasks, int batchSize) {
+    private Matrix[] getGradients(Matrix x, Matrix y, double weightPenalty, Matrix[] dropoutMasks, int batchSize) {
 
         int numLayers = myThetas.length+1;
 
@@ -278,15 +396,15 @@ private Matrix[] getGradients(Matrix x, Matrix y, double lambda, Matrix[] dropou
             thetaGrad[layer].scale(1.0/batchSize);
         }
 
-        if (lambda > 0) {
+        if (weightPenalty > 0) {
             // Add regularization terms
             int numNodes = numberOfNodes();
             for (int thetaNr = 0; thetaNr < numLayers-1 ; thetaNr++) {
                 Matrix theta = myThetas[thetaNr];
                 Matrix grad = thetaGrad[thetaNr];
                 for (int row = 0; row < grad.numRows() ; row++) {
                     for (int col = 1; col < grad.numColumns(); col++) {
-                        double regTerm = lambda/numNodes*Math.signum(theta.get(row, col));
+                        double regTerm = weightPenalty/numNodes*Math.signum(theta.get(row, col));
                         grad.add(row, col, regTerm);
                     }
                 }
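
A compact, self-contained restatement (not from the commit) of how the L1 penalty passed as weightPenalty enters the cost in getCost() and the gradient in getGradients(): both terms are scaled by weightPenalty/numberOfNodes() and skip the bias column. The array-based helpers below are purely illustrative:

public class L1PenaltySketch {
    // Cost contribution: weightPenalty/numNodes * sum(|theta|), skipping the bias column (col 0).
    static double l1Cost(double[][] theta, double weightPenalty, int numNodes) {
        double sum = 0.0;
        for (double[] row : theta)
            for (int col = 1; col < row.length; col++)
                sum += Math.abs(row[col]);
        return sum * weightPenalty / numNodes;
    }

    // Gradient contribution for one weight: weightPenalty/numNodes * sign(theta).
    static double l1Gradient(double thetaValue, double weightPenalty, int numNodes) {
        return weightPenalty / numNodes * Math.signum(thetaValue);
    }

    public static void main(String[] args) {
        double[][] theta = {{0.5, -2.0, 3.0}, {1.0, 4.0, -1.0}}; // column 0 = bias, not penalized
        System.out.println("cost term: " + l1Cost(theta, 1.0E-8, 100));
        System.out.println("gradient term for theta[0][1]: " + l1Gradient(theta[0][1], 1.0E-8, 100));
    }
}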
@@ -312,7 +430,7 @@ private Matrix[] generateDropoutMasks(int examples) {
         return masks;
     }
 
-    private void trainOneIterationThreaded(List<Matrix> batchesX, List<Matrix> batchesY, final double alpha, final double lambda) throws Exception {
+    private void trainOneIterationThreaded(List<Matrix> batchesX, List<Matrix> batchesY, final double learningRate, final double weightPenalty) throws Exception {
         final int batchSize = batchesX.get(0).numRows();
 
         // Queue up all batches for gradient computation in the thread pool.
@@ -326,11 +444,11 @@ public void run() {
                     boolean useDropout = myInputLayerDropoutRate > 0.0 || myHiddenLayersDropoutRate > 0.0;
                     Matrix[] dropoutMasks = useDropout ? generateDropoutMasks(by.numRows()) : null;
                     myThetasLock.readLock().lock();
-                    Matrix[] gradients = getGradients(bx, by, lambda, dropoutMasks, batchSize);
+                    Matrix[] gradients = getGradients(bx, by, weightPenalty, dropoutMasks, batchSize);
                     myThetasLock.readLock().unlock();
                     myThetasLock.writeLock().lock();
                     for (int theta = 0; theta < myThetas.length; theta++) {
-                        myThetas[theta].add(-alpha, gradients[theta]);
+                        myThetas[theta].add(-learningRate, gradients[theta]);
                     }
                     myThetasLock.writeLock().unlock();
                 }
@@ -352,7 +470,7 @@ private Matrix[] deepCopy(Matrix[] ms) {
         return res;
     }
 
-    private double findInitialAlpha(Matrix x, Matrix y, double lambda, boolean debug) throws Exception {
+    private double findInitialLearningRate(Matrix x, Matrix y, double weightPenalty, boolean debug) throws Exception {
         int numUsedTrainingExamples = 5000;
         int batchSize = 100;
         int numBatches = numUsedTrainingExamples/batchSize;
@@ -370,32 +488,32 @@ private double findInitialAlpha(Matrix x, Matrix y, double lambda, boolean debug
         }
 
         Matrix[] startThetas = deepCopy(myThetas);
-        double alpha = 1.0E-10;
-        trainOneIterationThreaded(batchesX, batchesY, alpha, lambda);
-        double cost = getCostThreaded(batchesX, batchesY, lambda);
+        double lr = 1.0E-10;
+        trainOneIterationThreaded(batchesX, batchesY, lr, weightPenalty);
+        double cost = getCostThreaded(batchesX, batchesY, weightPenalty);
         if (debug) {
-            System.out.println("\n\nAuto-finding learning rate, alpha");
-            System.out.println("Alpha: " + String.format("%.1E", alpha) + " Cost: " + cost); ////////////////////////////
+            System.out.println("\n\nAuto-finding learning rate.");
+            System.out.println("Learning rate: " + String.format("%.1E", lr) + " Cost: " + cost); ////////////////////////////
         }
         myThetas = deepCopy(startThetas);
         double lastCost = Double.MAX_VALUE;
-        double lastAlpha = alpha;
+        double lastLR = lr;
         while (cost < lastCost) {
             lastCost = cost;
-            lastAlpha = alpha;
-            alpha = alpha*10.0;
-            trainOneIterationThreaded(batchesX, batchesY, alpha, lambda);
-            cost = getCostThreaded(batchesX, batchesY, lambda);
+            lastLR = lr;
+            lr = lr*10.0;
+            trainOneIterationThreaded(batchesX, batchesY, lr, weightPenalty);
+            cost = getCostThreaded(batchesX, batchesY, weightPenalty);
             if (debug) {
-                System.out.println("Alpha: " + String.format("%.1E", alpha) + " Cost: " + cost); ////////////////////////////
+                System.out.println("Learning rate: " + String.format("%.1E", lr) + " Cost: " + cost); ////////////////////////////
             }
             myThetas = deepCopy(startThetas);
         }
 
         if (debug) {
-            System.out.println("Using alpha: " + String.format("%.1E", lastAlpha));
+            System.out.println("Using learning rate: " + String.format("%.1E", lastLR));
         }
-        return lastAlpha;
+        return lastLR;
     }
 
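
The hunk above renames findInitialAlpha to findInitialLearningRate but keeps its search strategy: starting from 1.0E-10, the rate is multiplied by 10 as long as one training iteration from the same starting weights still lowers the cost, and the last rate that did is returned. A standalone sketch of that loop (not part of the commit; costAfterOneIteration is a hypothetical stand-in for trainOneIterationThreaded(...) followed by getCostThreaded(...)):

import java.util.function.DoubleUnaryOperator;

public class LearningRateSearchSketch {
    static double findInitialLearningRate(DoubleUnaryOperator costAfterOneIteration) {
        double lr = 1.0E-10;
        double cost = costAfterOneIteration.applyAsDouble(lr);
        double lastCost = Double.MAX_VALUE;
        double lastLR = lr;
        while (cost < lastCost) {
            lastCost = cost;
            lastLR = lr;
            lr *= 10.0;                                      // try a 10x larger rate
            cost = costAfterOneIteration.applyAsDouble(lr);  // one iteration from the same start weights
        }
        return lastLR; // largest rate that still decreased the cost
    }

    public static void main(String[] args) {
        // Toy cost curve: minimized around lr = 1.0E-3, growing on either side.
        double chosen = findInitialLearningRate(lr -> Math.abs(Math.log10(lr) + 3));
        System.out.println("Chosen initial learning rate: " + chosen);
    }
}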

0 commit comments
