Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/mllib-clustering.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on
a given dataset, the algorithm returns the best clustering result).
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
* *epsilon* determines the distance threshold within which we consider k-means to have converged.
* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed.

**Examples**

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,21 @@ class KMeans private (
this
}

// Initial cluster centers can be provided as a KMeansModel object rather than using the
// random or k-means|| initializationMode
private var initialModel: Option[KMeansModel] = None

/**
* Set the initial starting point (the initial cluster centers), bypassing both the
* random and the k-means|| initialization.
* The condition model.k == this.k must be met; otherwise an IllegalArgumentException
* is thrown.
*/
def setInitialModel(model: KMeansModel): this.type = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just wondering out loud, and I don't know if this makes sense -- should the user have to supply a whole model just to specify initial centroids? Or can they just specify the centroids here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, that's fine.

require(model.k == k, "mismatched cluster count")
initialModel = Some(model)
this
}

/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
Expand Down Expand Up @@ -193,20 +208,34 @@ class KMeans private (

val initStartTime = System.nanoTime()

val centers = if (initializationMode == KMeans.RANDOM) {
initRandom(data)
// Only one run is allowed when initialModel is given
val numRuns = if (initialModel.nonEmpty) {
if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
1
} else {
initKMeansParallel(data)
runs
}

val centers = initialModel match {
case Some(kMeansCenters) => {
Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
}
case None => {
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
}
}
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
" seconds.")

val active = Array.fill(runs)(true)
val costs = Array.fill(runs)(0.0)
val active = Array.fill(numRuns)(true)
val costs = Array.fill(numRuns)(0.0)

var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
var iteration = 0

val iterationStartTime = System.nanoTime()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}

test("Initialize using given cluster centers") {
  // Four 2-D input points, distributed over three partitions.
  val corners = Seq(
    Vectors.dense(0.0, 0.0),
    Vectors.dense(1.0, 0.0),
    Vectors.dense(0.0, 1.0),
    Vectors.dense(1.0, 1.0)
  )
  val data = sc.parallelize(corners, 3)

  // Seed model whose two centers are taken directly from the input points.
  val seedModel = new KMeansModel(Array(corners(0), corners(2)))

  // k matches the seed model's center count; the iteration cap is set to 0.
  val configured = new KMeans()
    .setK(2)
    .setMaxIterations(0)
    .setInitialModel(seedModel)
  val trained = configured.run(data)

  // The returned centers are expected to equal the supplied ones, position by position.
  for (i <- 0 until 2) {
    assert(trained.clusterCenters(i) === seedModel.clusterCenters(i))
  }
}

}

object KMeansSuite extends SparkFunSuite {
Expand Down