import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
- * Created by Johannes Amtén on 2014-02-24.
+ * Neural network implementation with dropout and rectified linear units.
+ * Can perform regression or classification.
+ * Training is done by multithreaded mini-batch gradient descent, using a native matrix library.
+ *
+ * @author Johannes Amtén
 *
 */

public class NeuralNetwork implements Serializable {

    private Matrix[] myThetas = null;
-    private boolean mySoftmax = true;
+    private boolean mySoftmax = false;
    private double myInputLayerDropoutRate = 0.0;
    private double myHiddenLayersDropoutRate = 0.0;

+    private double[] myAverages = null;
+    private double[] myStdDevs = null;
+    private int[] myNumCategories = null;
+    private int myNumClasses = 0;
+
    private transient final ReentrantReadWriteLock myThetasLock = new ReentrantReadWriteLock();
    private transient ExecutorService myExecutorService;

+    /**
+     * Create an empty Neural Network.
+     * Use train() to generate weights.
+     */
    public NeuralNetwork() {
    }

-    public void train(Matrix x, Matrix y, int[] hiddenUnits, double lambda, double alpha, int batchSize, int iterations,
-                      int threads, double inputLayerDropoutRate, double hiddenLayersDropoutRate, boolean softmax, boolean debug) throws Exception {
+    /**
+     * Train the neural network.
+     *
+     * @param x Training data, input.
+     *          One row for each training example.
+     *          One column for each attribute.
+     * @param numCategories Number of categories for each nominal attribute in x.
+     *                      This array should have the same length as the number of columns in x.
+     *                      For each nominal attribute, the value should be equal to the number of categories of that attribute.
+     *                      For each numeric attribute, the value should be 1.
+     *                      If numCategories is null, all attributes in x will be interpreted as numeric.
+     * @param y Training data, correct output.
+     * @param numClasses Number of classes, if classification.
+     *                   1 if regression.
+     * @param hiddenUnits Number of units in each hidden layer.
+     * @param weightPenalty L1 weight penalty.
+     *                      Even when using dropout, it may be a good idea to have a small weight penalty, to keep weights down and avoid overflow.
+     * @param learningRate Initial learning rate.
+     *                     If 0, different learning rates will be tried to automatically find a good initial rate.
+     * @param batchSize Number of examples to use in each mini-batch.
+     * @param iterations Number of iterations (epochs) of training to perform.
+     *                   Training may be halted earlier by the user if the debug flag is set.
+     * @param threads Number of concurrent calculations.
+     *                If 0, threads will automatically be set to the number of CPU cores found.
+     * @param inputLayerDropoutRate Dropout rate of the input layer.
+     *                              Typically set somewhat lower than the dropout rate in the hidden layers, e.g. 0.2.
+     * @param hiddenLayersDropoutRate Dropout rate of the hidden layers.
+     *                                Typically set somewhat higher than the dropout rate in the input layer, e.g. 0.5.
+     * @param debug If true, training progress will be output to the console.
+     *              Also, the user will be able to halt training by pressing enter in the console.
+     * @param normalizeNumericData If true, the data in all numeric columns will be normalized, by subtracting the average and dividing by the standard deviation.
+     * @throws Exception
+     */
+    public void train(Matrix x, int[] numCategories, Matrix y, int numClasses, int[] hiddenUnits, double weightPenalty, double learningRate, int batchSize, int iterations,
+                      int threads, double inputLayerDropoutRate, double hiddenLayersDropoutRate, boolean debug, boolean normalizeNumericData) throws Exception {
+        myNumCategories = numCategories;
+        if (myNumCategories == null) {
+            myNumCategories = new int[x.numColumns()];
+            Arrays.fill(myNumCategories, 1);
+        }
+        myNumClasses = numClasses;
+        mySoftmax = myNumClasses > 1;
+
+        if (normalizeNumericData) {
+            x = x.copy();
+            myAverages = new double[x.numColumns()];
+            myStdDevs = new double[x.numColumns()];
+            for (int col = 0; col < x.numColumns(); col++) {
+                if (myNumCategories[col] <= 1) {
+                    // Normalize numeric column.
+                    myAverages[col] = MatrixUtils.getAverage(x, col);
+                    myStdDevs[col] = MatrixUtils.getStandardDeviation(x, col);
+                    MatrixUtils.normalizeData(x, col, myAverages[col], myStdDevs[col]);
+                }
+            }
+        } else {
+            myAverages = null;
+            myStdDevs = null;
+        }
+
+        // Expand nominal values to groups of booleans.
+        x = MatrixUtils.expandNominalAttributes(x, myNumCategories);
+        y = MatrixUtils.expandNominalAttributes(y, new int[]{myNumClasses});
+
+
        initThetas(x.numColumns(), hiddenUnits, y.numColumns());
-        mySoftmax = softmax;
        myInputLayerDropoutRate = inputLayerDropoutRate;
        myHiddenLayersDropoutRate = hiddenLayersDropoutRate;
        // If threads == 0, use the same number of threads as cores.
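A minimal usage sketch of the new train() signature documented above. The NeuralNetwork, Matrix and train() API come from this commit; the toy data, the hyperparameter values and the surrounding context are hypothetical, and the call is assumed to run inside a method that declares throws Exception:

// Sketch: training a small two-class classifier with the signature above (toy data).
// Column 0 of x is numeric, column 1 is a nominal attribute with 3 categories.
Matrix x = new Matrix(new double[][] {
        {0.5, 0}, {1.2, 2}, {-0.3, 1}, {2.2, 0}
});
Matrix y = new Matrix(new double[][] { {0}, {1}, {1}, {0} });
int[] numCategories = {1, 3};          // 1 = numeric column, 3 = nominal with 3 categories

NeuralNetwork nn = new NeuralNetwork();
nn.train(x, numCategories, y,
        2,                             // numClasses: classification with 2 classes
        new int[] {10, 10},            // two hidden layers of 10 units each
        1e-8,                          // small L1 weight penalty
        0.0,                           // learningRate 0: auto-find a good initial rate
        100,                           // batchSize
        50,                            // iterations (epochs)
        0,                             // threads 0: one per CPU core
        0.2,                           // input layer dropout rate
        0.5,                           // hidden layers dropout rate
        true,                          // debug output to console
        true);                         // normalize numeric columns

With learningRate set to 0, the initial rate is picked by the auto-search shown further down in this diff (findInitialLearningRate).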
@@ -47,35 +122,35 @@ public void train(Matrix x, Matrix y, int[] hiddenUnits, double lambda, double a
        List<Matrix> batchesX = new ArrayList<>();
        List<Matrix> batchesY = new ArrayList<>();
        MatrixUtils.split(x, y, batchSize, batchesX, batchesY);
-        if (alpha == 0.0) {
-            // Auto-find initial alpha.
-            alpha = findInitialAlpha(x, y, lambda, debug);
+        if (learningRate == 0.0) {
+            // Auto-find initial learning rate.
+            learningRate = findInitialLearningRate(x, y, weightPenalty, debug);
        }

-        double cost = getCostThreaded(batchesX, batchesY, lambda);
+        double cost = getCostThreaded(batchesX, batchesY, weightPenalty);
        LinkedList<Double> oldCosts = new LinkedList<>();
        if (debug) {
            System.out.println("\n\n*** Training network. Press <enter> to halt. ***\n");
-            System.out.println("Iteration: 0" + ", Cost: " + String.format("%.3E", cost) + ", Alpha: " + String.format("%.1E", alpha));
+            System.out.println("Iteration: 0" + ", Cost: " + String.format("%.3E", cost) + ", Learning rate: " + String.format("%.1E", learningRate));
        }
        for (int i = 0; i < iterations && !halted; i++) {
            // Regenerate the batches each iteration, to get random samples each time.
            MatrixUtils.split(x, y, batchSize, batchesX, batchesY);
-            trainOneIterationThreaded(batchesX, batchesY, alpha, lambda);
-            cost = getCostThreaded(batchesX, batchesY, lambda);
+            trainOneIterationThreaded(batchesX, batchesY, learningRate, weightPenalty);
+            cost = getCostThreaded(batchesX, batchesY, weightPenalty);

            if (oldCosts.size() == 5) {
                // Lower learning rate if cost hasn't decreased for 5 iterations.
                double oldCost = oldCosts.remove();
                double minCost = Math.min(cost, Collections.min(oldCosts));
                if (minCost >= oldCost) {
-                    alpha = alpha*0.1;
+                    learningRate = learningRate*0.1;
                    oldCosts.clear();
                }
            }

            if (debug) {
-                System.out.println("Iteration: " + (i + 1) + ", Cost: " + String.format("%.3E", cost) + ", Alpha: " + String.format("%.1E", alpha));
+                System.out.println("Iteration: " + (i + 1) + ", Cost: " + String.format("%.3E", cost) + ", Learning rate: " + String.format("%.1E", learningRate));
            }
            oldCosts.add(cost);

@@ -92,17 +167,60 @@ public void train(Matrix x, Matrix y, int[] hiddenUnits, double lambda, double a
        myExecutorService.shutdown();
    }

+    /**
+     * Get predictions for a number of input examples.
+     *
+     * @param x Matrix with one row for each input example and one column for each input attribute.
+     * @return Matrix with one row for each example.
+     *         If regression, only one column containing the predicted value.
+     *         If classification, one column for each class, containing the predicted probability of that class.
+     */
    public Matrix getPredictions(Matrix x) {
-        Matrix[] a = feedForward(x, null);
-        return a[a.length-1];
+        if (myAverages != null) {
+            x = x.copy();
+            for (int col = 0; col < x.numColumns(); col++) {
+                if (myNumCategories[col] <= 1) {
+                    // Normalize numeric column.
+                    MatrixUtils.normalizeData(x, col, myAverages[col], myStdDevs[col]);
+                }
+            }
+        }
+        // Expand nominal values to groups of booleans.
+        x = MatrixUtils.expandNominalAttributes(x, myNumCategories);
+
+        Matrix[] activations = feedForward(x, null);
+        return activations[activations.length-1];
    }

-    public double[] getPredictions(double[] x) {
-        Matrix xMatrix = new Matrix(new double[][]{x});
-        Matrix[] a = feedForward(xMatrix, null);
-        return a[a.length-1].getData();
+    /**
+     * Get classification predictions for a number of input examples.
+     *
+     * @param x Matrix with one row for each input example and one column for each input attribute.
+     * @return Array with the predicted class for each example.
+     */
+    public int[] getPredictedClasses(Matrix x) {
+        Matrix y = getPredictions(x);
+        int[] predictedClasses = new int[x.numRows()];
+        for (int row = 0; row < y.numRows(); row++) {
+            int prediction = 0;
+            double predMaxValue = Double.NEGATIVE_INFINITY;
+            for (int col = 0; col < y.numColumns(); col++) {
+                if (y.get(row, col) > predMaxValue) {
+                    predMaxValue = y.get(row, col);
+                    prediction = col;
+                }
+            }
+            predictedClasses[row] = prediction;
+        }
+        return predictedClasses;
    }

+//    public double[] getPredictions(double[] x) {
+//        Matrix xMatrix = new Matrix(new double[][]{x});
+//        Matrix[] a = feedForward(xMatrix, null);
+//        return a[a.length-1].getData();
+//    }
+
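A short sketch of the two prediction methods added above, continuing the hypothetical classifier from the earlier training example; the example rows here are made up:

// Sketch: predicting with the trained network.
Matrix newExamples = new Matrix(new double[][] { {0.8, 1}, {-1.1, 2} });

// One row per example, one column per class, holding predicted probabilities.
Matrix probabilities = nn.getPredictions(newExamples);

// One predicted class index per example.
int[] classes = nn.getPredictedClasses(newExamples);
for (int i = 0; i < classes.length; i++) {
    System.out.println("Example " + i + " -> class " + classes[i]
            + " (p = " + probabilities.get(i, classes[i]) + ")");
}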
    private void initThetas(int inputs, int[] hidden, int outputs) {
        ArrayList<Integer> numNodes = new ArrayList<>();
        numNodes.add(inputs);
@@ -176,7 +294,7 @@ private int numberOfNodes() {
        return nodes;
    }

-    private double getCost(Matrix x, Matrix y, double lambda, int batchSize) {
+    private double getCost(Matrix x, Matrix y, double weightPenalty, int batchSize) {
        double c = 0.0;

        Matrix[] a = feedForward(x, null);
@@ -204,21 +322,21 @@ private double getCost(Matrix x, Matrix y, double lambda, int batchSize) {
            c = c/(2*batchSize);
        }

-        if (lambda > 0) {
+        if (weightPenalty > 0) {
            // Regularization
            double regSum = 0.0;
            for (Matrix theta : myThetas) {
                for (MatrixElement me : theta.getColumns(1, -1)) {
                    regSum += Math.abs(me.value());
                }
            }
-            c += regSum*lambda/numberOfNodes();
+            c += regSum*weightPenalty/numberOfNodes();
        }

        return c;
    }

-    private double getCostThreaded(List<Matrix> batchesX, List<Matrix> batchesY, final double lambda) throws Exception {
+    private double getCostThreaded(List<Matrix> batchesX, List<Matrix> batchesY, final double weightPenalty) throws Exception {
        final int batchSize = batchesX.get(0).numRows();
        // Queue up cost calculation in thread pool
        List<Future<Double>> costJobs = new ArrayList<>();
@@ -228,7 +346,7 @@ private double getCostThreaded(List<Matrix> batchesX, List<Matrix> batchesY, fin
            Callable<Double> costCalculator = new Callable<Double>() {
                public Double call() throws Exception {
                    myThetasLock.readLock().lock();
-                    double cost = getCost(bx, by, lambda, batchSize);
+                    double cost = getCost(bx, by, weightPenalty, batchSize);
                    myThetasLock.readLock().unlock();
                    return cost;
                }
@@ -245,7 +363,7 @@ public Double call() throws Exception {
        return cost;
    }

-    private Matrix[] getGradients(Matrix x, Matrix y, double lambda, Matrix[] dropoutMasks, int batchSize) {
+    private Matrix[] getGradients(Matrix x, Matrix y, double weightPenalty, Matrix[] dropoutMasks, int batchSize) {

        int numLayers = myThetas.length+1;

@@ -278,15 +396,15 @@ private Matrix[] getGradients(Matrix x, Matrix y, double lambda, Matrix[] dropou
            thetaGrad[layer].scale(1.0/batchSize);
        }

-        if (lambda > 0) {
+        if (weightPenalty > 0) {
            // Add regularization terms
            int numNodes = numberOfNodes();
            for (int thetaNr = 0; thetaNr < numLayers-1; thetaNr++) {
                Matrix theta = myThetas[thetaNr];
                Matrix grad = thetaGrad[thetaNr];
                for (int row = 0; row < grad.numRows(); row++) {
                    for (int col = 1; col < grad.numColumns(); col++) {
-                        double regTerm = lambda/numNodes*Math.signum(theta.get(row, col));
+                        double regTerm = weightPenalty/numNodes*Math.signum(theta.get(row, col));
                        grad.add(row, col, regTerm);
                    }
                }
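The regularization terms in getCost() and getGradients() amount to a scaled L1 penalty and its subgradient. Below is a standalone sketch of the same idea on plain double[][] weights, not part of the commit; the helper names are made up, and the bias column 0 is skipped just as in the methods above:

// Sketch: L1 penalty and its subgradient, mirroring getCost()/getGradients().
static double l1Penalty(double[][] weights, double weightPenalty, int numNodes) {
    double sum = 0.0;
    for (double[] row : weights) {
        for (int col = 1; col < row.length; col++) {   // skip bias column 0
            sum += Math.abs(row[col]);
        }
    }
    return sum * weightPenalty / numNodes;             // term added to the cost
}

static void addL1Subgradient(double[][] weights, double[][] grad, double weightPenalty, int numNodes) {
    for (int row = 0; row < weights.length; row++) {
        for (int col = 1; col < weights[row].length; col++) {
            // d/dw |w| = signum(w), scaled the same way as the penalty itself.
            grad[row][col] += weightPenalty / numNodes * Math.signum(weights[row][col]);
        }
    }
}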
@@ -312,7 +430,7 @@ private Matrix[] generateDropoutMasks(int examples) {
        return masks;
    }

-    private void trainOneIterationThreaded(List<Matrix> batchesX, List<Matrix> batchesY, final double alpha, final double lambda) throws Exception {
+    private void trainOneIterationThreaded(List<Matrix> batchesX, List<Matrix> batchesY, final double learningRate, final double weightPenalty) throws Exception {
        final int batchSize = batchesX.get(0).numRows();

        // Queue up all batches for gradient computation in the thread pool.
@@ -326,11 +444,11 @@ public void run() {
                    boolean useDropout = myInputLayerDropoutRate > 0.0 || myHiddenLayersDropoutRate > 0.0;
                    Matrix[] dropoutMasks = useDropout ? generateDropoutMasks(by.numRows()) : null;
                    myThetasLock.readLock().lock();
-                    Matrix[] gradients = getGradients(bx, by, lambda, dropoutMasks, batchSize);
+                    Matrix[] gradients = getGradients(bx, by, weightPenalty, dropoutMasks, batchSize);
                    myThetasLock.readLock().unlock();
                    myThetasLock.writeLock().lock();
                    for (int theta = 0; theta < myThetas.length; theta++) {
-                        myThetas[theta].add(-alpha, gradients[theta]);
+                        myThetas[theta].add(-learningRate, gradients[theta]);
                    }
                    myThetasLock.writeLock().unlock();
                }
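The batch workers above take the read lock while computing gradients, so several batches can read the current weights concurrently, and take the write lock only for the actual weight update. A minimal, generic sketch of that locking pattern follows; the weights array, the fixed learning rate and the computeGradient() helper are hypothetical, not part of this class:

// Sketch: read lock for concurrent gradient computation, write lock for the exclusive update.
final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
final double[] weights = new double[100];

Runnable batchWorker = () -> {
    double[] gradient;
    lock.readLock().lock();
    try {
        gradient = computeGradient(weights);       // hypothetical helper; only reads weights
    } finally {
        lock.readLock().unlock();
    }
    lock.writeLock().lock();
    try {
        for (int i = 0; i < weights.length; i++) {
            weights[i] -= 0.01 * gradient[i];      // exclusive update of the shared weights
        }
    } finally {
        lock.writeLock().unlock();
    }
};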
@@ -352,7 +470,7 @@ private Matrix[] deepCopy(Matrix[] ms) {
        return res;
    }

-    private double findInitialAlpha(Matrix x, Matrix y, double lambda, boolean debug) throws Exception {
+    private double findInitialLearningRate(Matrix x, Matrix y, double weightPenalty, boolean debug) throws Exception {
        int numUsedTrainingExamples = 5000;
        int batchSize = 100;
        int numBatches = numUsedTrainingExamples/batchSize;
@@ -370,32 +488,32 @@ private double findInitialAlpha(Matrix x, Matrix y, double lambda, boolean debug
        }

        Matrix[] startThetas = deepCopy(myThetas);
-        double alpha = 1.0E-10;
-        trainOneIterationThreaded(batchesX, batchesY, alpha, lambda);
-        double cost = getCostThreaded(batchesX, batchesY, lambda);
+        double lr = 1.0E-10;
+        trainOneIterationThreaded(batchesX, batchesY, lr, weightPenalty);
+        double cost = getCostThreaded(batchesX, batchesY, weightPenalty);
        if (debug) {
-            System.out.println("\n\nAuto-finding learning rate, alpha");
-            System.out.println("Alpha: " + String.format("%.1E", alpha) + " Cost: " + cost);
+            System.out.println("\n\nAuto-finding learning rate.");
+            System.out.println("Learning rate: " + String.format("%.1E", lr) + " Cost: " + cost);
        }
        myThetas = deepCopy(startThetas);
        double lastCost = Double.MAX_VALUE;
-        double lastAlpha = alpha;
+        double lastLR = lr;
        while (cost < lastCost) {
            lastCost = cost;
-            lastAlpha = alpha;
-            alpha = alpha*10.0;
-            trainOneIterationThreaded(batchesX, batchesY, alpha, lambda);
-            cost = getCostThreaded(batchesX, batchesY, lambda);
+            lastLR = lr;
+            lr = lr*10.0;
+            trainOneIterationThreaded(batchesX, batchesY, lr, weightPenalty);
+            cost = getCostThreaded(batchesX, batchesY, weightPenalty);
            if (debug) {
-                System.out.println("Alpha: " + String.format("%.1E", alpha) + " Cost: " + cost);
+                System.out.println("Learning rate: " + String.format("%.1E", lr) + " Cost: " + cost);
            }
            myThetas = deepCopy(startThetas);
        }

        if (debug) {
-            System.out.println("Using alpha: " + String.format("%.1E", lastAlpha));
+            System.out.println("Using learning rate: " + String.format("%.1E", lastLR));
        }
-        return lastAlpha;
+        return lastLR;
    }


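findInitialLearningRate() above starts at 1.0E-10 and multiplies the rate by 10 after each probe, always restarting from the saved start weights, until the cost stops improving; it then returns the last rate that still lowered the cost. A compact sketch of the same geometric search is shown below against a hypothetical evaluate(rate) function, standing in for "train one iteration from the start weights at that rate and return the resulting cost":

// Sketch: geometric search for an initial learning rate, mirroring the method above.
static double findRate(java.util.function.DoubleUnaryOperator evaluate) {
    double rate = 1.0E-10;
    double lastRate = rate;
    double lastCost = Double.MAX_VALUE;
    double cost = evaluate.applyAsDouble(rate);
    while (cost < lastCost) {          // keep growing the rate while the cost still improves
        lastCost = cost;
        lastRate = rate;
        rate *= 10.0;
        cost = evaluate.applyAsDouble(rate);
    }
    return lastRate;                   // last rate that still lowered the cost
}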