Skip to content

Commit b62a67c

Browse files
committed
[SYSTEMML-850] Cache-conscious sparse-dense wcemm block operations
1 parent d6990dc commit b62a67c

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2943,19 +2943,35 @@ private static void matrixMultWCeMMSparseDense(MatrixBlock mW, MatrixBlock mU, M
29432943
SparseBlock w = mW.sparseBlock;
29442944
double[] u = mU.denseBlock;
29452945
double[] v = mV.denseBlock;
2946+
final int n = mW.clen;
29462947
final int cd = mU.clen;
29472948
double wceval = 0;
29482949

2949-
// approach: iterate over all cells of X and
2950-
for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd ) {
2951-
if( !w.isEmpty(i) ) {
2952-
int wpos = w.pos(i);
2953-
int wlen = w.size(i);
2954-
int[] wix = w.indexes(i);
2955-
double[] wval = w.values(i);
2956-
for( int k=wpos; k<wpos+wlen; k++ ) {
2957-
double uvij = dotProduct(u, v, uix, wix[k]*cd, cd);
2958-
wceval += wval[k] * FastMath.log(uvij + eps);
2950+
// approach: iterate over W, point-wise in order to exploit sparsity
2951+
// blocked over ij, while maintaining front of column indexes, where the
2952+
// blocksize is chosen such that we reuse each vector on average 8 times.
2953+
final int blocksizeIJ = (int) (8L*mW.rlen*mW.clen/mW.nonZeros);
2954+
int[] curk = new int[blocksizeIJ];
2955+
2956+
for( int bi=rl; bi<ru; bi+=blocksizeIJ ) {
2957+
int bimin = Math.min(ru, bi+blocksizeIJ);
2958+
//prepare starting indexes for block row
2959+
Arrays.fill(curk, 0);
2960+
//blocked execution over column blocks
2961+
for( int bj=0; bj<n; bj+=blocksizeIJ ) {
2962+
int bjmin = Math.min(n, bj+blocksizeIJ);
2963+
for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd ) {
2964+
if( w.isEmpty(i) ) continue;
2965+
int wpos = w.pos(i);
2966+
int wlen = w.size(i);
2967+
int[] wix = w.indexes(i);
2968+
double[] wval = w.values(i);
2969+
int k = wpos + curk[i-bi];
2970+
for( ; k<wpos+wlen && wix[k]<bjmin; k++ ) {
2971+
double uvij = dotProduct(u, v, uix, wix[k]*cd, cd);
2972+
wceval += wval[k] * FastMath.log(uvij + eps);
2973+
}
2974+
curk[i-bi] = k - wpos;
29592975
}
29602976
}
29612977
}

0 commit comments

Comments
 (0)