Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 60 additions & 39 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,21 @@ private[spark] object BLAS extends Serializable with Logging {
* y += a * x
*/
private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = {
val nnz = x.indices.size
val xValues = x.values
val xIndices = x.indices
val yValues = y.values
val nnz = xIndices.size

if (a == 1.0) {
var k = 0
while (k < nnz) {
y.values(x.indices(k)) += x.values(k)
yValues(xIndices(k)) += xValues(k)
k += 1
}
} else {
var k = 0
while (k < nnz) {
y.values(x.indices(k)) += a * x.values(k)
yValues(xIndices(k)) += a * xValues(k)
k += 1
}
}
Expand Down Expand Up @@ -119,11 +123,15 @@ private[spark] object BLAS extends Serializable with Logging {
* dot(x, y)
*/
private def dot(x: SparseVector, y: DenseVector): Double = {
val nnz = x.indices.size
val xValues = x.values
val xIndices = x.indices
val yValues = y.values
val nnz = xIndices.size

var sum = 0.0
var k = 0
while (k < nnz) {
sum += x.values(k) * y.values(x.indices(k))
sum += xValues(k) * yValues(xIndices(k))
k += 1
}
sum
Expand All @@ -133,19 +141,24 @@ private[spark] object BLAS extends Serializable with Logging {
* dot(x, y)
*/
private def dot(x: SparseVector, y: SparseVector): Double = {
val xValues = x.values
val xIndices = x.indices
val yValues = y.values
val yIndices = y.indices
val nnzx = xIndices.size
val nnzy = yIndices.size

var kx = 0
val nnzx = x.indices.size
var ky = 0
val nnzy = y.indices.size
var sum = 0.0
// y catching x
while (kx < nnzx && ky < nnzy) {
val ix = x.indices(kx)
while (ky < nnzy && y.indices(ky) < ix) {
val ix = xIndices(kx)
while (ky < nnzy && yIndices(ky) < ix) {
ky += 1
}
if (ky < nnzy && y.indices(ky) == ix) {
sum += x.values(kx) * y.values(ky)
if (ky < nnzy && yIndices(ky) == ix) {
sum += xValues(kx) * yValues(ky)
ky += 1
}
kx += 1
Expand All @@ -163,21 +176,25 @@ private[spark] object BLAS extends Serializable with Logging {
case dy: DenseVector =>
x match {
case sx: SparseVector =>
val sxIndices = sx.indices
val sxValues = sx.values
val dyValues = dy.values
val nnz = sxIndices.size

var i = 0
var k = 0
val nnz = sx.indices.size
while (k < nnz) {
val j = sx.indices(k)
val j = sxIndices(k)
while (i < j) {
dy.values(i) = 0.0
dyValues(i) = 0.0
i += 1
}
dy.values(i) = sx.values(k)
dyValues(i) = sxValues(k)
i += 1
k += 1
}
while (i < n) {
dy.values(i) = 0.0
dyValues(i) = 0.0
i += 1
}
case dx: DenseVector =>
Expand Down Expand Up @@ -311,6 +328,8 @@ private[spark] object BLAS extends Serializable with Logging {
s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")

val Avals = A.values
val Bvals = B.values
val Cvals = C.values
val Arows = if (!transA) A.rowIndices else A.colPtrs
val Acols = if (!transA) A.colPtrs else A.rowIndices

Expand All @@ -327,11 +346,11 @@ private[spark] object BLAS extends Serializable with Logging {
val indEnd = Arows(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
sum += Avals(i) * B.values(Bstart + Acols(i))
sum += Avals(i) * Bvals(Bstart + Acols(i))
i += 1
}
val Cindex = Cstart + rowCounterForA
C.values(Cindex) = beta * C.values(Cindex) + sum * alpha
Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
rowCounterForA += 1
}
colCounterForB += 1
Expand All @@ -349,15 +368,15 @@ private[spark] object BLAS extends Serializable with Logging {
i += 1
}
val Cindex = Cstart + rowCounter
C.values(Cindex) = beta * C.values(Cindex) + sum * alpha
Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
rowCounter += 1
}
colCounterForB += 1
}
}
} else {
// Scale matrix first if `beta` is not equal to 0.0
if (beta != 0.0){
if (beta != 0.0) {
f2jBLAS.dscal(C.values.length, beta, C.values, 1)
}
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
Expand All @@ -371,9 +390,9 @@ private[spark] object BLAS extends Serializable with Logging {
while (colCounterForA < kA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val Bval = B.values(Bstart + colCounterForA) * alpha
while (i < indEnd){
C.values(Cstart + Arows(i)) += Avals(i) * Bval
val Bval = Bvals(Bstart + colCounterForA) * alpha
while (i < indEnd) {
Cvals(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
Expand All @@ -384,12 +403,12 @@ private[spark] object BLAS extends Serializable with Logging {
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Cstart = colCounterForB * mA
while (colCounterForA < kA){
while (colCounterForA < kA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val Bval = B(colCounterForB, colCounterForA) * alpha
while (i < indEnd){
C.values(Cstart + Arows(i)) += Avals(i) * Bval
while (i < indEnd) {
Cvals(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
Expand Down Expand Up @@ -484,41 +503,43 @@ private[spark] object BLAS extends Serializable with Logging {
beta: Double,
y: DenseVector): Unit = {

val mA: Int = if(!trans) A.numRows else A.numCols
val nA: Int = if(!trans) A.numCols else A.numRows
val xValues = x.values
val yValues = y.values

val mA: Int = if (!trans) A.numRows else A.numCols
val nA: Int = if (!trans) A.numCols else A.numRows

val Avals = A.values
val Arows = if (!trans) A.rowIndices else A.colPtrs
val Acols = if (!trans) A.colPtrs else A.rowIndices

// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
if (trans){
if (trans) {
var rowCounter = 0
while (rowCounter < mA){
while (rowCounter < mA) {
var i = Arows(rowCounter)
val indEnd = Arows(rowCounter + 1)
var sum = 0.0
while(i < indEnd){
sum += Avals(i) * x.values(Acols(i))
while (i < indEnd) {
sum += Avals(i) * xValues(Acols(i))
i += 1
}
y.values(rowCounter) = beta * y.values(rowCounter) + sum * alpha
yValues(rowCounter) = beta * yValues(rowCounter) + sum * alpha
rowCounter += 1
}
} else {
// Scale vector first if `beta` is not equal to 0.0
if (beta != 0.0){
if (beta != 0.0) {
scal(beta, y)
}
// Perform matrix-vector multiplication and add to y
var colCounterForA = 0
while (colCounterForA < nA){
while (colCounterForA < nA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val xVal = x.values(colCounterForA) * alpha
while (i < indEnd){
val xVal = xValues(colCounterForA) * alpha
while (i < indEnd) {
val rowIndex = Arows(i)
y.values(rowIndex) += Avals(i) * xVal
yValues(rowIndex) += Avals(i) * xVal
i += 1
}
colCounterForA += 1
Expand Down