mllib/src/main/scala/spark/mllib/clustering/KMeans.scala (41 changes: 21 additions, 20 deletions)
@@ -24,6 +24,7 @@ import spark.{SparkContext, RDD}
import spark.SparkContext._
import spark.Logging
import spark.mllib.util.MLUtils
import spark.mllib.math.vector.{Vector, DenseVector}

import org.jblas.DoubleMatrix

@@ -45,7 +46,7 @@ class KMeans private (
var epsilon: Double)
extends Serializable with Logging
{
private type ClusterCenters = Array[Array[Double]]
private type ClusterCenters = Array[Vector]

def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)

@@ -112,7 +113,7 @@ class KMeans private (
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
def train(data: RDD[Array[Double]]): KMeansModel = {
def train(data: RDD[Vector]): KMeansModel = {
// TODO: check whether data is persistent; this needs RDD.storageLevel to be publicly readable

val sc = data.sparkContext
@@ -131,9 +132,9 @@

// Execute iterations of Lloyd's algorithm until all runs have converged
while (iteration < maxIterations && !activeRuns.isEmpty) {
type WeightedPoint = (DoubleMatrix, Long)
type WeightedPoint = (Vector, Long)
def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
(p1._1.addi(p2._1), p1._2 + p2._2)
(p1._1 += p2._1, p1._2 + p2._2)
}

val activeCenters = activeRuns.map(r => centers(r)).toArray
@@ -143,15 +144,15 @@
val totalContribs = data.mapPartitions { points =>
val runs = activeCenters.length
val k = activeCenters(0).length
val dims = activeCenters(0)(0).length
val dims = activeCenters(0)(0).dimension

val sums = Array.fill(runs, k)(new DoubleMatrix(dims))
val sums = Array.fill(runs, k)(activeCenters(0)(0).like())
val counts = Array.fill(runs, k)(0L)

for (point <- points; (centers, runIndex) <- activeCenters.zipWithIndex) {
val (bestCenter, cost) = KMeans.findClosest(centers, point)
costAccums(runIndex) += cost
sums(runIndex)(bestCenter).addi(new DoubleMatrix(point))
sums(runIndex)(bestCenter) += point
counts(runIndex)(bestCenter) += 1
}

@@ -167,8 +168,8 @@
for (j <- 0 until k) {
val (sum, count) = totalContribs((i, j))
if (count != 0) {
val newCenter = sum.divi(count).data
if (MLUtils.squaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
val newCenter = sum / count
if (newCenter.distanceSquared(centers(run)(j)) > epsilon * epsilon) {
changed = true
}
centers(run)(j) = newCenter
@@ -192,7 +193,7 @@
/**
* Initialize `runs` sets of cluster centers at random.
*/
private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = {
private def initRandom(data: RDD[Vector]): Array[ClusterCenters] = {
// Sample all the cluster centers in one pass to avoid repeated scans
val sample = data.takeSample(true, runs * k, new Random().nextInt())
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k))
@@ -207,7 +208,7 @@
*
* The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
*/
private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = {
private def initKMeansParallel(data: RDD[Vector]): Array[ClusterCenters] = {
// Initialize each run's center to a random point
val seed = new Random().nextInt()
val sample = data.takeSample(true, runs, seed)
@@ -260,7 +261,7 @@ object KMeans {
val K_MEANS_PARALLEL = "k-means||"

def train(
data: RDD[Array[Double]],
data: RDD[Vector],
k: Int,
maxIterations: Int,
runs: Int,
@@ -274,24 +275,24 @@
.train(data)
}

def train(data: RDD[Array[Double]], k: Int, maxIterations: Int, runs: Int): KMeansModel = {
def train(data: RDD[Vector], k: Int, maxIterations: Int, runs: Int): KMeansModel = {
train(data, k, maxIterations, runs, K_MEANS_PARALLEL)
}

def train(data: RDD[Array[Double]], k: Int, maxIterations: Int): KMeansModel = {
def train(data: RDD[Vector], k: Int, maxIterations: Int): KMeansModel = {
train(data, k, maxIterations, 1, K_MEANS_PARALLEL)
}

/**
* Return the index of the closest point in `centers` to `point`, as well as its squared distance.
*/
private[mllib] def findClosest(centers: Array[Array[Double]], point: Array[Double])
private[mllib] def findClosest(centers: Array[Vector], point: Vector)
: (Int, Double) =
{
var bestDistance = Double.PositiveInfinity
var bestIndex = 0
for (i <- 0 until centers.length) {
val distance = MLUtils.squaredDistance(point, centers(i))
val distance = point.distanceSquared(centers(i))
if (distance < bestDistance) {
bestDistance = distance
bestIndex = i
@@ -303,10 +304,10 @@ object KMeans {
/**
* Return the K-means cost of a given point against the given cluster centers.
*/
private[mllib] def pointCost(centers: Array[Array[Double]], point: Array[Double]): Double = {
private[mllib] def pointCost(centers: Array[Vector], point: Vector): Double = {
var bestDistance = Double.PositiveInfinity
for (i <- 0 until centers.length) {
val distance = MLUtils.squaredDistance(point, centers(i))
val distance = point.distanceSquared(centers(i))
if (distance < bestDistance) {
bestDistance = distance
}
@@ -321,12 +322,12 @@
}
val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt)
val sc = new SparkContext(master, "KMeans")
val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble))
val data = sc.textFile(inputFile).map(line => new DenseVector(line.split(' ').map(_.toDouble)).asInstanceOf[Vector])
val model = KMeans.train(data, k, iters)
val cost = model.computeCost(data)
println("Cluster centers:")
for (c <- model.clusterCenters) {
println(" " + c.mkString(" "))
println(" " + c)
}
println("Cost: " + cost)
System.exit(0)
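The diff programs against a new `Vector` trait from `spark.mllib.math.vector`, whose definition is not part of this changeset. Below is a minimal sketch of the surface that KMeans.scala exercises, inferred from the call sites above; every name and signature is an assumption about that package, not its actual contents:

```scala
// Hypothetical sketch of the Vector abstraction this diff relies on,
// inferred from its call sites; the real trait may differ.
trait Vector extends Serializable {
  def dimension: Int                          // activeCenters(0)(0).dimension
  def +=(other: Vector): Vector               // in-place add returning this (mergeContribs)
  def /(scalar: Double): Vector               // element-wise division (sum / count)
  def distanceSquared(other: Vector): Double  // replaces MLUtils.squaredDistance
  def like(): Vector                          // zero vector of same type and dimension
  def like(vs: Array[Double]): Vector         // same type, given values (LocalKMeans)
  def toArray: Array[Double]                  // bridge to jblas (LocalKMeans)
}

class DenseVector(val values: Array[Double]) extends Vector {
  def dimension: Int = values.length

  def +=(other: Vector): Vector = {
    val o = other.toArray
    var i = 0
    while (i < values.length) { values(i) += o(i); i += 1 }
    this
  }

  def /(scalar: Double): Vector = new DenseVector(values.map(_ / scalar))

  def distanceSquared(other: Vector): Double = {
    val o = other.toArray
    var sum = 0.0
    var i = 0
    while (i < values.length) { val d = values(i) - o(i); sum += d * d; i += 1 }
    sum
  }

  def like(): Vector = new DenseVector(new Array[Double](values.length))
  def like(vs: Array[Double]): Vector = new DenseVector(vs)
  def toArray: Array[Double] = values
}
```

LocalKMeans.scala (below) additionally imports `RandomAccessSparseVector`, which suggests the `like`/`toArray` indirection exists so that cluster centers keep the concrete type of the input points, sparse or dense.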
mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala (7 changes: 4 additions, 3 deletions)
@@ -20,25 +20,26 @@ package spark.mllib.clustering
import spark.RDD
import spark.SparkContext._
import spark.mllib.util.MLUtils
import spark.mllib.math.vector.{Vector, DenseVector}


/**
* A clustering model for K-means. Each point belongs to the cluster with the closest center.
*/
class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable {
class KMeansModel(val clusterCenters: Array[Vector]) extends Serializable {
/** Total number of clusters. */
def k: Int = clusterCenters.length

/** Return the cluster index that a given point belongs to. */
def predict(point: Array[Double]): Int = {
def predict(point: Vector): Int = {
KMeans.findClosest(clusterCenters, point)._1
}

/**
* Return the K-means cost (sum of squared distances of points to their nearest center) for this
* model on the given data.
*/
def computeCost(data: RDD[Array[Double]]): Double = {
def computeCost(data: RDD[Vector]): Double = {
data.map(p => KMeans.pointCost(clusterCenters, p)).sum
}
}
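With the public API now typed against `Vector`, callers wrap raw feature arrays in a concrete implementation before training. A hypothetical end-to-end use of the changed API, mirroring `main()` above; the input path and parameter values are illustrative only:

```scala
import spark.SparkContext
import spark.mllib.clustering.KMeans
import spark.mllib.math.vector.{Vector, DenseVector}

val sc = new SparkContext("local", "KMeansExample")

// Parse one space-separated point per line, as main() does above.
val data = sc.textFile("kmeans_data.txt")  // illustrative path
  .map(line => new DenseVector(line.split(' ').map(_.toDouble)).asInstanceOf[Vector])
  .cache()  // train() is iterative, so the doc comment asks for cached input

val model = KMeans.train(data, 3, 20)  // k = 3, up to 20 iterations
val cluster = model.predict(new DenseVector(Array(0.5, 1.0)))
val cost = model.computeCost(data)
println("Point (0.5, 1.0) -> cluster " + cluster + ", total cost " + cost)
```

The `asInstanceOf[Vector]` upcast mirrors `main()` above: `RDD` is invariant in its element type, so an `RDD[DenseVector]` would not typecheck where `train` expects an `RDD[Vector]`.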
mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala (14 changes: 8 additions, 6 deletions)
@@ -19,6 +19,8 @@ package spark.mllib.clustering

import scala.util.Random

import spark.mllib.math.vector.{Vector, DenseVector, RandomAccessSparseVector}

import org.jblas.{DoubleMatrix, SimpleBlas}

/**
@@ -32,15 +34,15 @@ private[mllib] object LocalKMeans {
*/
def kMeansPlusPlus(
seed: Int,
points: Array[Array[Double]],
points: Array[Vector],
weights: Array[Double],
k: Int,
maxIterations: Int)
: Array[Array[Double]] =
: Array[Vector] =
{
val rand = new Random(seed)
val dimensions = points(0).length
val centers = new Array[Array[Double]](k)
val dimensions = points(0).dimension
val centers = new Array[Vector](k)

// Initialize centers by sampling using the k-means++ procedure
centers(0) = pickWeighted(rand, points, weights)
@@ -70,7 +72,7 @@
val counts = Array.fill(k)(0.0)
for ((p, i) <- points.zipWithIndex) {
val index = KMeans.findClosest(centers, p)._1
SimpleBlas.axpy(weights(i), new DoubleMatrix(p), sums(index))
SimpleBlas.axpy(weights(i), new DoubleMatrix(p.toArray), sums(index))
counts(index) += weights(i)
if (index != oldClosest(i)) {
moved = true
@@ -83,7 +85,7 @@
// Assign center to a random point
centers(i) = points(rand.nextInt(points.length))
} else {
centers(i) = sums(i).divi(counts(i)).data
centers(i) = points(0).like(sums(i).toArray()) / counts(i)
}
}
iteration += 1
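One bridge to jblas remains in `kMeansPlusPlus`: points go in through `new DoubleMatrix(p.toArray)` and come back out through `like(sums(i).toArray())`. A sketch of keeping the weighted accumulation entirely in the `Vector` API; it assumes a scalar multiply `*(Double): Vector` on `Vector`, which this diff does not show:

```scala
// Hypothetical pure-Vector accumulation for the Lloyd's step in
// kMeansPlusPlus. Assumes Vector also exposes `*(Double): Vector`
// (scalar multiply), an assumption beyond this diff.
val sums = Array.fill(k)(points(0).like())   // zero vectors of the points' concrete type
val counts = Array.fill(k)(0.0)
for ((p, i) <- points.zipWithIndex) {
  val index = KMeans.findClosest(centers, p)._1
  sums(index) += p * weights(i)              // was: SimpleBlas.axpy(weights(i), ..., sums(index))
  counts(index) += weights(i)
}
// ...and when recomputing centers:
// centers(i) = sums(i) / counts(i)          // was: points(0).like(sums(i).toArray()) / counts(i)
```

The trade-off is one temporary vector per point where `SimpleBlas.axpy` updated `sums(index)` in place; an in-place axpy-style method on `Vector` would drop the `DoubleMatrix` bridge without the extra allocation.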