Changes to support KMeans with large feature space
levin-royl committed Jan 13, 2016
commit 33d760c7d848da66d8a84523f11a7fc38ff1afc4

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Since
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
@@ -45,7 +45,9 @@ class KMeans private (
     private var initializationMode: String,
     private var initializationSteps: Int,
     private var epsilon: Double,
-    private var seed: Long) extends Serializable with Logging {
+    private var seed: Long,
+    private var vectorFactory: VectorFactory = DenseVectorFactory.instance
+  ) extends Serializable with Logging {
 
   /**
    * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
@@ -176,6 +178,13 @@ class KMeans private (
     this
   }
 
+  def getVectorFactory: VectorFactory = vectorFactory
+
+  def setVectorFactory(vectorFactory: VectorFactory): this.type = {
+    this.vectorFactory = vectorFactory
+    this
+  }
+
   // Initial cluster centers can be provided as a KMeansModel object rather than using the
   // random or k-means|| initializationMode
   private var initialModel: Option[KMeansModel] = None
@@ -282,7 +291,8 @@ class KMeans private (
       val k = thisActiveCenters(0).length
       val dims = thisActiveCenters(0)(0).vector.size
 
-      val sums = Array.fill(runs, k)(Vectors.zeros(dims))
+      // val sums = Array.fill(runs, k)(Vectors.zeros(dims))
+      val sums = Array.fill(runs, k)(vectorFactory.zeros(dims))
       val counts = Array.fill(runs, k)(0L)
 
       points.foreach { point =>
@@ -376,7 +386,8 @@ class KMeans private (
     // Initialize each run's first center to a random point.
     val seed = new XORShiftRandom(this.seed).nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
-    val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+    // val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+    val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).compact(vectorFactory)))
 
     /** Merges new centers to centers. */
     def mergeNewCenters(): Unit = {
@@ -436,7 +447,8 @@ class KMeans private (
       }.collect()
       mergeNewCenters()
       chosen.foreach { case (p, rs) =>
-        rs.foreach(newCenters(_) += p.toDense)
+        // rs.foreach(newCenters(_) += p.toDense)
+        rs.foreach(newCenters(_) += p)
       }
       step += 1
     }
@@ -459,7 +471,7 @@ class KMeans private (
     val finalCenters = (0 until runs).par.map { r =>
       val myCenters = centers(r).toArray
       val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
-      LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
+      LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30, vectorFactory)
     }
 
     finalCenters.toArray
@@ -488,6 +500,7 @@ object KMeans {
    * @param runs number of parallel runs, defaults to 1. The best model is returned.
    * @param initializationMode initialization mode, either "random" or "k-means||" (default).
    * @param seed random seed value for cluster initialization
+   * @param vectorFactory factory used to create the vectors that represent the centroids
    */
   @Since("1.3.0")
   def train(
@@ -496,12 +509,14 @@ object KMeans {
       maxIterations: Int,
       runs: Int,
       initializationMode: String,
-      seed: Long): KMeansModel = {
+      seed: Long,
+      vectorFactory: VectorFactory = DenseVectorFactory.instance): KMeansModel = {
     new KMeans().setK(k)
       .setMaxIterations(maxIterations)
       .setRuns(runs)
       .setInitializationMode(initializationMode)
       .setSeed(seed)
+      .setVectorFactory(vectorFactory)
       .run(data)
   }
 
@@ -617,5 +632,33 @@ class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable
   def this(array: Array[Double]) = this(Vectors.dense(array))
 
   /** Converts the vector to a dense vector. */
-  def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
+  // def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
+
+  def compact(fact: VectorFactory): VectorWithNorm = new VectorWithNorm(fact.compact(vector), norm)
 }
+
+trait VectorFactory extends Serializable {
+  def zeros(size: Int): Vector
+
+  def compact(vec: Vector): Vector
+}
+
+class DenseVectorFactory private() extends VectorFactory {
+  override def zeros(size: Int): Vector = Vectors.zeros(size)
+
+  override def compact(vec: Vector): Vector = vec.toDense
+}
+
+object DenseVectorFactory {
+  val instance = new DenseVectorFactory
+}
+
+class SmartVectorFactory private() extends VectorFactory {
+  override def zeros(size: Int): Vector = new SparseVector(size, Array.empty, Array.empty)
+
+  override def compact(vec: Vector): Vector = vec.compressed
+}
+
+object SmartVectorFactory {
+  val instance = new SmartVectorFactory
+}
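
For orientation, a minimal usage sketch of the API this commit introduces. The snippet is illustrative: it assumes an RDD[Vector] named data with a wide, mostly sparse feature space is already loaded, and it uses only names defined in this patch or in stock MLlib.

    import org.apache.spark.mllib.clustering.{KMeans, SmartVectorFactory}

    // With SmartVectorFactory, each centroid accumulator starts as an empty SparseVector
    // and each finished center is stored via Vector.compressed (dense or sparse,
    // whichever is smaller), instead of always being densified.
    val model = KMeans.train(
      data,                        // data: RDD[Vector], assumed to be loaded elsewhere
      10,                          // k
      20,                          // maxIterations
      1,                           // runs
      KMeans.K_MEANS_PARALLEL,     // initializationMode
      42L,                         // seed
      SmartVectorFactory.instance) // vectorFactory

The default, DenseVectorFactory.instance, preserves the existing behavior of converting every centroid to a dense vector.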

mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
@@ -38,14 +38,15 @@ private[mllib] object LocalKMeans extends Logging {
       points: Array[VectorWithNorm],
       weights: Array[Double],
       k: Int,
-      maxIterations: Int
+      maxIterations: Int,
+      vectorFactory: VectorFactory
   ): Array[VectorWithNorm] = {
     val rand = new Random(seed)
     val dimensions = points(0).vector.size
     val centers = new Array[VectorWithNorm](k)
 
     // Initialize centers by sampling using the k-means++ procedure.
-    centers(0) = pickWeighted(rand, points, weights).toDense
+    centers(0) = pickWeighted(rand, points, weights).compact(vectorFactory)
     for (i <- 1 until k) {
       // Pick the next center with a probability proportional to cost under current centers
       val curCenters = centers.view.take(i)
@@ -62,9 +63,9 @@ private[mllib] object LocalKMeans extends Logging {
       if (j == 0) {
         logWarning("kMeansPlusPlus initialization ran out of distinct points for centers." +
           s" Using duplicate point for center k = $i.")
-        centers(i) = points(0).toDense
+        centers(i) = points(0).compact(vectorFactory)
       } else {
-        centers(i) = points(j - 1).toDense
+        centers(i) = points(j - 1).compact(vectorFactory)
       }
     }
 
@@ -93,7 +94,7 @@ private[mllib] object LocalKMeans extends Logging {
     while (j < k) {
       if (counts(j) == 0.0) {
         // Assign center to a random point
-        centers(j) = points(rand.nextInt(points.length)).toDense
+        centers(j) = points(rand.nextInt(points.length)).compact(vectorFactory)
       } else {
         scal(1.0 / counts(j), sums(j))
         centers(j) = new VectorWithNorm(sums(j))

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala (74 additions, 0 deletions)
@@ -38,6 +38,42 @@ private[spark] object BLAS extends Serializable with Logging {
     _f2jBLAS
   }
 
+  // equivalent to xIndices.union(yIndices).distinct.sortBy(i => i)
+  def unitedIndices(xSortedIndices: Array[Int], ySortedIndices: Array[Int]): Array[Int] = {
+    val arr = new Array[Int](xSortedIndices.length + ySortedIndices.length)
+    (0 until arr.length).foreach(i => arr(i) = -1)
+
+    var xj = 0
+    var yj = 0
+    var j = 0
+    var previ = Int.MaxValue
+
+    def getAt(arr: Array[Int], j: Int): Int = if (j < arr.length) arr(j) else Int.MaxValue
+
+    while (xj < xSortedIndices.length || yj < ySortedIndices.length) {
+      val xi = getAt(xSortedIndices, xj)
+      val yi = getAt(ySortedIndices, yj)
+
+      val i = if (xi <= yi) {
+        xj += 1
+        xi
+      }
+      else {
+        yj += 1
+        yi
+      }
+
+      if (previ != i) {
+        arr(j) = i
+        j += 1
+      }
+
+      previ = i
+    }
+
+    arr.filter(_ != -1)
+  }
+
   /**
    * y += a * x
   */
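
A quick illustration of unitedIndices (the values are made up for the example): it computes the sorted, de-duplicated union of two sorted index arrays in a single merge pass, so it runs in O(x.length + y.length). The -1 sentinel marking unused slots is safe here because valid sparse-vector indices are never negative.

    BLAS.unitedIndices(Array(0, 3, 7), Array(3, 5))
    // returns Array(0, 3, 5, 7)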
@@ -54,6 +90,16 @@ private[spark] object BLAS extends Serializable with Logging {
           throw new UnsupportedOperationException(
             s"axpy doesn't support x type ${x.getClass}.")
         }
+      case sy: SparseVector =>
+        x match {
+          case sx: SparseVector =>
+            axpy(a, sx, sy)
+          case dx: DenseVector =>
+            axpy(a, dx, sy)
+          case _ =>
+            throw new UnsupportedOperationException(
+              s"axpy doesn't support x type ${x.getClass}.")
+        }
       case _ =>
         throw new IllegalArgumentException(
           s"axpy only supports adding to a dense vector but got type ${y.getClass}.")
@@ -92,6 +138,34 @@ private[spark] object BLAS extends Serializable with Logging {
     }
   }
 
+  /**
+   * y += a * x
+   */
+  private def axpy(a: Double, x: DenseVector, y: SparseVector): Unit = {
+    require(x.size == y.size)
+
+    val xIndices = (0 until x.size).filter(i => x(i) != 0.0).toArray
+    val xValues = xIndices.map(i => x(i))
+
+    axpy(a, Vectors.sparse(x.size, xIndices, xValues), y)
+  }
+
+  /**
+   * y += a * x
+   */
+  private def axpy(a: Double, x: SparseVector, y: SparseVector): Unit = {
+    require(x.size == y.size)
+
+    val xIndices = x.indices
+    val yIndices = y.indices
+
+    val newIndices = unitedIndices(xIndices, yIndices)
+    assert(newIndices.size >= yIndices.size)
+
+    val newValues = newIndices.map(i => a*x(i) + y(i))
+    y.reassign(newIndices, newValues)
+  }
+
   /** Y += a * x */
   private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
     require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " +
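
A sketch of what the new dispatch arm allows (illustrative values; BLAS is private[spark], so this only compiles from code inside the Spark tree, e.g. with import org.apache.spark.mllib.linalg.BLAS.axpy as KMeans.scala already does):

    val y = new SparseVector(5, Array(1), Array(2.0))        // (0.0, 2.0, 0.0, 0.0, 0.0)
    val x = Vectors.sparse(5, Array(1, 3), Array(1.0, 4.0))  // (0.0, 1.0, 0.0, 4.0, 0.0)
    axpy(0.5, x, y)                                          // y += 0.5 * x, updating y in place
    // y now holds indices [1, 3] and values [2.5, 2.0]

Two trade-offs are visible above: the sparse-into-sparse path allocates fresh index and value arrays on every call and evaluates a*x(i) + y(i) through SparseVector's apply, effectively a search per merged index; and the IllegalArgumentException message in the fallback case ("axpy only supports adding to a dense vector") is now stale, since a SparseVector accumulator is accepted.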

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala (25 additions, 9 deletions)
@@ -700,21 +700,37 @@ object DenseVector {
  * A sparse vector represented by an index array and a value array.
  *
  * @param size size of the vector.
- * @param indices index array, assume to be strictly increasing.
- * @param values value array, must have the same length as the index array.
+ * @param sortedIndices index array, assumed to be strictly increasing.
+ * @param sortedValues value array, must have the same length as the index array.
  */
 @Since("1.0.0")
 @SQLUserDefinedType(udt = classOf[VectorUDT])
 class SparseVector @Since("1.0.0") (
     @Since("1.0.0") override val size: Int,
-    @Since("1.0.0") val indices: Array[Int],
-    @Since("1.0.0") val values: Array[Double]) extends Vector {
+    @Since("1.0.0") private var sortedIndices: Array[Int],
+    @Since("1.0.0") private var sortedValues: Array[Double]) extends Vector {
 
+  require(allRequirements())
+
+  def allRequirements(): Boolean = {
+    require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
+      s" indices match the dimension of the values. You provided ${indices.length} indices and " +
+      s" ${values.length} values.")
+    require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
+      s"which exceeds the specified vector size ${size}.")
+
+    true
+  }
+
+  def reassign(newSortedIndices: Array[Int], newValues: Array[Double]): Unit = {
+    sortedIndices = newSortedIndices
+    sortedValues = newValues
+    require(allRequirements())
+  }
+
+  def indices: Array[Int] = sortedIndices
+
-  require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
-    s" indices match the dimension of the values. You provided ${indices.length} indices and " +
-    s" ${values.length} values.")
-  require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
-    s"which exceeds the specified vector size ${size}.")
+  def values: Array[Double] = sortedValues
 
   override def toString: String =
     s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"