@@ -62,7 +62,7 @@ public static Matrix antiPoolDelta(Matrix delta, Matrix prePoolRowIndexes, int n
6262 return result ;
6363 }
6464
65- public static Matrix generatePatchesFromInputLayer (Matrix inputs , int inputWidth , int patchWidth , int patchHeight ) {
65+ public static Matrix generatePatchesFromInputLayer (Matrix inputs , int inputWidth , int inputHeight , int patchWidth , int patchHeight ) {
6666 // Input data has one row per example.
6767 // Input data has one column per pixel.
6868 // Assumes pixel-numbers are generated row-wise.
@@ -71,20 +71,22 @@ public static Matrix generatePatchesFromInputLayer(Matrix inputs, int inputWidth
7171 // Output data have one row per example/patch
7272 // Output data have one column per patchPixel.
7373
74- int inputHeight = inputs .numColumns () / inputWidth ;
74+ int numChannels = inputs .numColumns () / ( inputWidth * inputHeight ) ;
7575 int numPatchesPerExample = (inputWidth -patchWidth +1 )*(inputHeight -patchHeight +1 );
7676 int numExamples = inputs .numRows ();
77- Matrix output = new Matrix (numExamples *numPatchesPerExample , patchWidth *patchHeight );
77+ Matrix output = new Matrix (numExamples *numPatchesPerExample , numChannels * patchWidth *patchHeight );
7878 for (int example = 0 ; example < numExamples ; example ++) {
7979 int patchNum = 0 ;
8080 for (int inputStartY = 0 ; inputStartY < inputHeight -patchHeight +1 ; inputStartY ++) {
8181 for (int inputStartX = 0 ; inputStartX < inputWidth -patchWidth +1 ; inputStartX ++) {
82- for (int patchPixelY = 0 ; patchPixelY < patchHeight ; patchPixelY ++) {
83- for (int patchPixelX = 0 ; patchPixelX < patchWidth ; patchPixelX ++) {
84- int inputY = inputStartY + patchPixelY ;
85- int inputX = inputStartX + patchPixelX ;
86- double value = inputs .get (example , inputY *inputWidth + inputX );
87- output .set (example *numPatchesPerExample + patchNum , patchPixelY *patchWidth + patchPixelX , value );
82+ for (int channel = 0 ; channel < numChannels ; channel ++) {
83+ for (int patchPixelY = 0 ; patchPixelY < patchHeight ; patchPixelY ++) {
84+ for (int patchPixelX = 0 ; patchPixelX < patchWidth ; patchPixelX ++) {
85+ int inputY = inputStartY + patchPixelY ;
86+ int inputX = inputStartX + patchPixelX ;
87+ double value = inputs .get (example , channel *inputHeight *inputWidth + inputY *inputWidth + inputX );
88+ output .set (example *numPatchesPerExample + patchNum , channel *patchHeight *patchWidth + patchPixelY *patchWidth + patchPixelX , value );
89+ }
8890 }
8991 }
9092 patchNum ++;
0 commit comments