Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Keep vectors sparse in Java when reading LabeledPoints
  • Loading branch information
mateiz committed Apr 15, 2014
commit 0e7a3d8599d6eb677e734cd3fadc27d6942a40f9
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ class PythonMLLibAPI extends Serializable {
dataBytesJRDD: JavaRDD[Array[Byte]],
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => {
val x = deserializeDoubleVector(xBytes).toArray // TODO: deal with sparse vectors here!
LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
val x = deserializeDoubleVector(xBytes)
LabeledPoint(x(0), x.slice(1, x.size))
})
val initialWeights = deserializeDoubleVector(initialWeightsBA)
val model = trainFunc(data, initialWeights)
Expand Down Expand Up @@ -300,8 +300,8 @@ class PythonMLLibAPI extends Serializable {
dataBytesJRDD: JavaRDD[Array[Byte]],
lambda: Double): java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => {
val x = deserializeDoubleVector(xBytes).toArray // TODO: make this efficient for sparse vecs
LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
val x = deserializeDoubleVector(xBytes)
LabeledPoint(x(0), x.slice(1, x.size))
})
val model = NaiveBayes.train(data, lambda)
val ret = new java.util.LinkedList[java.lang.Object]()
Expand Down
49 changes: 46 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ trait Vector extends Serializable {
* @param i index
*/
private[mllib] def apply(i: Int): Double = toBreeze(i)

/**
 * Returns the sub-vector covering positions `start` (inclusive) up to `end` (exclusive).
 *
 * @param start position of the first element to keep
 * @param end position one past the last element to keep
 * @return a vector of the selected elements, re-indexed so that `start` becomes position 0
 */
private[mllib] def slice(start: Int, end: Int): Vector
}

/**
Expand Down Expand Up @@ -130,9 +132,11 @@ object Vectors {
private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = {
breezeVector match {
case v: BDV[Double] =>
require(v.offset == 0, s"Do not support non-zero offset ${v.offset}.")
require(v.stride == 1, s"Do not support stride other than 1, but got ${v.stride}.")
new DenseVector(v.data)
if (v.offset == 0 && v.stride == 1) {
new DenseVector(v.data)
} else {
new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one
}
case v: BSV[Double] =>
new SparseVector(v.length, v.index, v.data)
case v: BV[_] =>
Expand All @@ -155,6 +159,10 @@ class DenseVector(val values: Array[Double]) extends Vector {
// Exposes this vector to Breeze by constructing a `BDV` directly from `values`.
// NOTE(review): presumably Breeze wraps the array rather than copying it — confirm.
private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)

/**
 * Returns the element at position `i` by reading the backing `values` array directly.
 *
 * @param i index of the element to read
 * @return the value stored at `i`
 */
override def apply(i: Int): Double = values(i)

/**
 * Returns a new dense vector holding the elements at positions `start` (inclusive) up to
 * `end` (exclusive).
 *
 * The range is validated with the same checks and messages as `SparseVector.slice`, so a
 * malformed range fails loudly for both representations instead of being silently clamped
 * by `Array.slice` for the dense one.
 *
 * @param start position of the first element to keep; must satisfy `0 <= start <= end`
 * @param end position one past the last element to keep; must not exceed `values.length`
 * @return a `DenseVector` of length `end - start`
 * @throws IllegalArgumentException if the range does not lie within `[0, values.length]`
 */
private[mllib] override def slice(start: Int, end: Int): Vector = {
  require(start <= end, s"invalid range: ${start} to ${end}")
  require(start >= 0, s"invalid range: ${start} to ${end}")
  require(end <= values.length, s"invalid range: ${start} to ${end}")
  new DenseVector(values.slice(start, end))
}
}

/**
Expand Down Expand Up @@ -185,4 +193,39 @@ class SparseVector(
}

// Exposes this vector to Breeze by constructing a `BSV` from `indices`, `values` and `size`.
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)

/**
 * Returns the value stored at position `pos`, or 0.0 when no entry is stored for it.
 *
 * Scans `indices` from the front and stops at the first entry that is equal to or greater
 * than `pos`, so the cost is linear in the number of stored entries.
 *
 * NOTE(review): `pos` is never checked against the vector's size; a negative or too-large
 * position falls through to 0.0 instead of failing.
 *
 * @param pos position to look up
 * @return the matching element of `values`, or 0.0 if `pos` does not occur in `indices`
 */
override def apply(pos: Int): Double = {
// A more efficient apply() than creating a new Breeze vector
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Breeze's sparse vector uses binary search for random access. I think in the current code base, only decision tree needs random access to a vector. However, we haven't claimed it supports sparse input yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I'll remove this and split() because they're no longer needed. They were needed when we passed vectors with the label included from Python instead of passing LabeledPoint.

var i = 0
while (i < indices.length) {
if (indices(i) == pos) {
return values(i)
} else if (indices(i) > pos) {
// Already past where `pos` would be stored; presumably `indices` is sorted ascending — TODO confirm
return 0.0
}
i += 1
}
0.0
}

/**
 * Returns the sub-vector covering positions `start` (inclusive) up to `end` (exclusive) as a
 * new sparse vector of size `end - start`.
 *
 * Stored entries whose index falls inside the range are kept, with `start` subtracted from
 * each index.
 *
 * @param start position of the first element to keep; must satisfy `0 <= start <= end`
 * @param end position one past the last element to keep; must not exceed `size`
 * @return a `SparseVector` holding the entries inside the range
 * @throws IllegalArgumentException if any of the `require` checks on the range fails
 */
private[mllib] override def slice(start: Int, end: Int): Vector = {
require(start <= end, s"invalid range: ${start} to ${end}")
require(start >= 0, s"invalid range: ${start} to ${end}")
require(end <= size, s"invalid range: ${start} to ${end}")
// Figure out the range of indices that fall within the given bounds
var i = 0
var indexRangeStart = 0
var indexRangeEnd = 0
// Skip the stored entries located before `start`.
// NOTE(review): both scans presumably rely on `indices` being sorted ascending — confirm.
while (i < indices.length && indices(i) < start) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use binary search instead. However, is it called somewhere? You already changed the way we serialize/deserialize vectors.

i += 1
}
indexRangeStart = i
// Continue from there past every stored entry located before `end`.
while (i < indices.length && indices(i) < end) {
i += 1
}
indexRangeEnd = i
// Re-base the kept indices so that `start` maps to position 0 of the result.
val newIndices = indices.slice(indexRangeStart, indexRangeEnd).map(_ - start)
val newValues = values.slice(indexRangeStart, indexRangeEnd)
new SparseVector(end - start, newIndices, newValues)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,42 @@ class VectorsSuite extends FunSuite {
assert(v.## != another.##)
}
}

test("indexing dense vectors") {
  // Check the first and the last position of a four-element dense vector.
  val dense = Vectors.dense(1.0, 2.0, 3.0, 4.0)
  val expectedByPosition = Seq(0 -> 1.0, 3 -> 4.0)
  expectedByPosition.foreach { case (position, expected) =>
    assert(dense(position) === expected)
  }
}

test("indexing sparse vectors") {
  // Same stored entries every time; only the declared size differs between the two vectors.
  def sparseOfSize(n: Int) = Vectors.sparse(n, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0))

  // Stored positions return their value, the gaps in between return zero.
  val lengthSeven = sparseOfSize(7)
  Seq(0 -> 1.0, 1 -> 0.0, 2 -> 2.0, 3 -> 0.0, 6 -> 4.0).foreach { case (position, expected) =>
    assert(lengthSeven(position) === expected)
  }

  // With one extra slot, the trailing position past the last stored entry is zero.
  val lengthEight = sparseOfSize(8)
  Seq(6 -> 4.0, 7 -> 0.0).foreach { case (position, expected) =>
    assert(lengthEight(position) === expected)
  }
}

test("slicing dense vectors") {
  // Slicing [1, 3) out of a dense vector keeps the two middle elements and stays dense.
  val middle = Vectors.dense(1.0, 2.0, 3.0, 4.0).slice(1, 3)
  val expected = Vectors.dense(2.0, 3.0)
  assert(middle === expected)
  assert(middle.isInstanceOf[DenseVector], "slice was not DenseVector")
}

test("slicing sparse vectors") {
  val source = Vectors.sparse(7, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0))

  // Each case pairs a (start, end) range with the sparse vector the slice must equal:
  // a range spanning two stored entries, a range with none, and the final stored entry.
  val cases = Seq(
    (1, 5) -> Vectors.sparse(4, Array(1,3), Array(2.0, 3.0)),
    (1, 2) -> Vectors.sparse(1, Array(), Array()),
    (6, 7) -> Vectors.sparse(1, Array(0), Array(4.0)))

  cases.foreach { case ((from, until), expected) =>
    val piece = source.slice(from, until)
    assert(piece === expected)
    assert(piece.isInstanceOf[SparseVector], "slice was not SparseVector")
  }
}
}