Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
change WLS
  • Loading branch information
hhbyyh committed Sep 15, 2015
commit f8f2633f33aeac27a4e2ce5d5cbdb50968ee474b
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private[ml] class WeightedLeastSquares(
if (fitIntercept) {
// shift centers
// A^T A - aBar aBar^T
RowMatrix.dspr(-1.0, aBar, aaValues)
BLAS.spr(-1.0, aBar, aaValues)
// A^T b - bBar aBar
BLAS.axpy(-bBar, aBar, abBar)
}
Expand Down Expand Up @@ -203,7 +203,7 @@ private[ml] object WeightedLeastSquares {
bbSum += w * b * b
BLAS.axpy(w, a, aSum)
BLAS.axpy(w * b, a, abSum)
RowMatrix.dspr(w, a, aaSum.values)
BLAS.spr(w, a, aaSum)
this
}

Expand Down
79 changes: 44 additions & 35 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,41 +175,6 @@ private[spark] object BLAS extends Serializable with Logging {
sum
}

/**
 * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's DSPR.
 *
 * @param alpha scalar multiplier applied to the rank-one term v * v.t
 * @param v the vector whose outer product is accumulated; dense or sparse
 * @param U the upper triangular part of the matrix packed in an array (column major);
 *          updated in place
 */
def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
  val n = v.size
  v match {
    case DenseVector(values) =>
      NativeBLAS.dspr("U", n, alpha, values, 1, U)
    case SparseVector(_, indices, values) =>
      val nnz = indices.length
      var colStartIdx = 0
      var prevCol = 0
      var j = 0
      while (j < nnz) {
        val col = indices(j)
        // Skip empty columns. Column c of the packed upper triangle begins at offset
        // c * (c + 1) / 2, so moving from prevCol to col advances by the amount below.
        // The product is formed as a Long because it can exceed Int.MaxValue even when
        // the halved result is still a valid array offset.
        colStartIdx += ((col - prevCol).toLong * (col + prevCol + 1) / 2).toInt
        val av = alpha * values(j)
        // Only entries on or above the diagonal are stored, hence i <= j.
        var i = 0
        while (i <= j) {
          U(colStartIdx + indices(i)) += av * values(i)
          i += 1
        }
        j += 1
        prevCol = col
      }
  }
}

/**
* y = x
*/
Expand Down Expand Up @@ -271,6 +236,50 @@ private[spark] object BLAS extends Serializable with Logging {
_nativeBLAS
}

/**
 * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
 *
 * Convenience overload that forwards the backing array of `U` to the array-based `spr`.
 *
 * @param U the upper triangular part of the matrix stored in a [[DenseVector]] (column major)
 */
def spr(alpha: Double, v: Vector, U: DenseVector): Unit =
  spr(alpha, v, U.values)

/**
 * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
 *
 * @param alpha scalar multiplier applied to the rank-one term v * v.t
 * @param v the vector whose outer product is accumulated; dense or sparse
 * @param U the upper triangular part of the matrix packed in an array (column major);
 *          updated in place
 */
def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
  val n = v.size
  v match {
    case DenseVector(values) =>
      NativeBLAS.dspr("U", n, alpha, values, 1, U)
    case SparseVector(_, indices, values) =>
      val nnz = indices.length
      var colStartIdx = 0
      var prevCol = 0
      var j = 0
      while (j < nnz) {
        val col = indices(j)
        // Skip empty columns. Column c of the packed upper triangle begins at offset
        // c * (c + 1) / 2, so moving from prevCol to col advances by the amount below.
        // The product is formed as a Long because it can exceed Int.MaxValue even when
        // the halved result is still a valid array offset.
        colStartIdx += ((col - prevCol).toLong * (col + prevCol + 1) / 2).toInt
        val av = alpha * values(j)
        // Only entries on or above the diagonal are stored, hence i <= j.
        var i = 0
        while (i <= j) {
          U(colStartIdx + indices(i)) += av * values(i)
          i += 1
        }
        j += 1
        prevCol = col
      }
  }
}

/**
* A := alpha * x * x^T^ + A
* @param alpha a real scalar that will be multiplied to x * x^T^.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class RowMatrix @Since("1.0.0") (
// Compute the upper triangular part of the gram matrix.
val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
seqOp = (U, v) => {
BLAS.dspr(1.0, v, U.data)
BLAS.spr(1.0, v, U.data)
U
}, combOp = (U1, U2) => U1 += U2)

Expand Down Expand Up @@ -671,6 +671,7 @@ class RowMatrix @Since("1.0.0") (
@Since("1.0.0")
@Experimental
object RowMatrix {

/**
* Fills a full square matrix from its upper triangular part.
*/
Expand Down
25 changes: 25 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,31 @@ class BLASSuite extends SparkFunSuite {
}
}

test("spr") {
  val alpha = 0.1

  // Dense input: every entry of the packed upper triangle receives a contribution.
  val denseX = new DenseVector(Array(1.0, 2, 2.1, 4))
  val packedDense = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4))
  val expectedDense =
    new DenseVector(Array(1.1, 2.2, 2.4, 3.21, 3.42, 3.441, 4.4, 4.8, 4.84, 5.6))
  spr(alpha, denseX, packedDense)
  assert(packedDense ~== expectedDense absTol 1e-9)

  // A packed array that is too short for a 4-element vector must be rejected.
  val undersized = new DenseVector(Array(1.0, 2, 3, 4, 5))
  withClue("Size of vector must match the rank of matrix") {
    intercept[Exception] {
      spr(alpha, denseX, undersized)
    }
  }

  // Sparse input: only entries (0,0), (0,3) and (3,3) of the packed triangle change.
  val sparseX = new SparseVector(4, Array(0, 3), Array(1.0, 2))
  val packedSparse = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4))
  val expectedSparse =
    new DenseVector(Array(1.1, 2.0, 2.0, 3.0, 3.0, 3.0, 4.2, 4.0, 4.0, 4.4))
  spr(alpha, sparseX, packedSparse)
  assert(packedSparse ~== expectedSparse absTol 1e-15)
}

test("syr") {
val dA = new DenseMatrix(4, 4,
Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
Expand Down