Closed
Changes from 2 commits
@@ -156,6 +156,20 @@ class KMeans private (
this
}

// Initial cluster centers can be provided as a KMeansModel object rather than using the
// random or k-means|| initializationMode
private var initialModel: Option[KMeansModel] = None

/** Set the initial starting point, bypassing the random initialization or k-means||
Member

Scala style: the comment text should begin on the line after /**. (See the other multi-line comments in this file for examples.)

* The condition (model.k == this.k) must be met; failure will result in an
* IllegalArgumentException.
*/
def setInitialModel(model: KMeansModel): this.type = {
Member

I'm just wondering out loud, and I don't know if this makes sense: should the user have to supply a whole model just to specify initial centroids, or could they just specify the centroids here?

Contributor Author

Member

Oh I see, that's fine.

require(model.k==k, "mismatched cluster count")
Member

Missing spaces around == here. You might also print the actual vs. expected cluster count in the message.
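As an illustration of this suggestion, here is a minimal sketch of a `require` call whose message reports both values. `ModelStub` and `checkClusterCount` are hypothetical names used only so the snippet runs without Spark; they are not part of the PR.

```scala
object RequireMessageSketch {
  // Hypothetical stand-in for KMeansModel: only the cluster count matters here.
  final case class ModelStub(k: Int)

  // Fails with an IllegalArgumentException whose message names both the
  // expected and the actual cluster count.
  def checkClusterCount(model: ModelStub, k: Int): Unit =
    require(model.k == k,
      s"mismatched cluster count: expected k = $k but the initial model has ${model.k} clusters")
}
```

Scala's `require` evaluates the message lazily, so the interpolated string is only built when the check fails.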

initialModel = Some(model)
this
}

/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
@@ -193,12 +207,19 @@ class KMeans private (

val initStartTime = System.nanoTime()

val centers = if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
val centers = initialModel match {
case Some(kMeansCenters) => {
Array.tabulate(runs)(r => kMeansCenters.clusterCenters
Member

This will run KMeans from the same initial centers `runs` times. It should only be run once.

Contributor Author

@jkbradley Is it OK to explicitly use the value 1 in place of runs here, or should we add a require(runs == 1, "can only run once") in setInitialModel()?

Member

I think we should:

  • Not change the value of runs, but ignore it when initialModel is given
  • Document this behavior in the Scala/Python docs for runs and initialModel
  • Add a logWarning which prints a warning message when runs is being ignored

Contributor Author

@jkbradley Irrespective of the value of runs set by the user, do you mean that we should reassign runs to 1 in setInitialModel() and print a warning message?
To document this behavior, can we do it in this PR itself, or should we open a new ticket for the documentation?

Member

I do not think setInitialModel should do anything other than set initialModel.

When training/fitting happens, I think we should then examine initialModel. If initialModel is set, then we should act as though runs = 1 (but not actually change that value) and also print the warning message.

The documentation should be added within this PR.
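One way to read this proposal is sketched below: the stored `runs` value is never mutated, and the training code derives an effective run count plus an optional warning at fit time. `effectiveRuns` is a hypothetical helper, and the warning text is only an example; in the real class the caller would presumably pass the message to `logWarning`.

```scala
object EffectiveRunsSketch {
  // Hypothetical helper: returns the number of runs to actually perform and,
  // when the user's setting is being ignored, a warning message to log.
  // The caller's own `runs` field is left untouched.
  def effectiveRuns(runs: Int, hasInitialModel: Boolean): (Int, Option[String]) =
    if (hasInitialModel && runs > 1) {
      (1, Some(s"Ignoring runs = $runs because an initial model was provided; running once."))
    } else if (hasInitialModel) {
      (1, None)
    } else {
      (runs, None)
    }
}
```

Keeping the decision in one pure function makes the "ignore but do not change" behavior easy to unit test separately from the clustering itself.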

.map(s => new VectorWithNorm(s, Vectors.norm(s, 2.0))))
}
case None => {
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
}
}
}

val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
" seconds.")
@@ -478,6 +499,25 @@ object KMeans {
train(data, k, maxIterations, runs, K_MEANS_PARALLEL)
}

/**
* Trains a k-means model using the given set of parameters and initial cluster centers
*
* @param data training points stored as `RDD[Vector]`
* @param k number of clusters
* @param maxIterations max number of iterations
* @param initialModel an existing set of cluster centers.
*/
def train(
Member

I'm not sure at this point what the thinking is on adding yet another overload to the utility method. At some point one is expected to use KMeans directly, and I recall some move to stop adding these utility methods. But I am not sure -- @mengxr @jkbradley any opinion?

Member

I agree. This extra static method is not necessary since we decided we prefer the builder pattern, as @srowen said.

data: RDD[Vector],
k: Int,
maxIterations: Int,
initialModel: KMeansModel): KMeansModel = {
new KMeans().setK(k)
.setMaxIterations(maxIterations)
.setInitialModel(initialModel)
.run(data)
}

/**
* Returns the index of the closest center to the given point, as well as the squared distance.
*/
@@ -278,6 +278,34 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}

test("Initialize using given cluster centers") {
val points = Seq(
Vectors.dense(0.0, 0.0),
Vectors.dense(0.0, 0.1),
Vectors.dense(0.1, 0.0),
Vectors.dense(9.0, 0.0),
Vectors.dense(9.0, 0.2),
Vectors.dense(9.2, 0.0)
)
val rdd = sc.parallelize(points, 3)
val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1)

val tempDir = Utils.createTempDir()
Member

Remove the temp dir after loading the model back.
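A minimal sketch of the cleanup pattern being requested, using only the JDK so it runs without Spark. `withTempDir` and `deleteRecursively` here are hypothetical helpers written for this illustration, not the utilities the test suite actually uses.

```scala
import java.io.File
import java.nio.file.Files

object TempDirSketch {
  // Deletes a file, or a directory together with everything beneath it.
  def deleteRecursively(file: File): Unit = {
    // listFiles() returns null for non-directories, hence the Option wrapper.
    Option(file.listFiles()).foreach(children => children.foreach(deleteRecursively))
    file.delete()
    ()
  }

  // Runs `body` with a fresh temporary directory and removes the directory
  // afterwards, even if `body` throws.
  def withTempDir[A](body: File => A): A = {
    val dir = Files.createTempDirectory("kmeans-model").toFile
    try body(dir)
    finally deleteRecursively(dir)
  }
}
```

In the test under review, the save and load calls would sit inside the `body` block so the directory is gone by the time the assertions run.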

val path = tempDir.toURI.toString
model.save(sc, path)
val loadedModel = KMeansModel.load(sc, path)

val newModel = KMeans.train(rdd, k = 2, maxIterations = 2, initialModel = loadedModel)
Member

This seems like a somewhat brittle test. It requires that the initial training find the correct centers. I imagine it generally would, but with so few data points, I suspect it will fail every now and then.

A better test might be:

  • have 4 data points A,B,C,D at the corners of a square
  • use k = 2
  • compare the results from starting with initial centers at A,C vs. B,D, using maxIterations = 1

(Basically, cluster centers should not move, so they should definitely be different in the end.)
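The idea can be tried out without Spark using a single hand-written Lloyd iteration, as in the sketch below. Everything in it (`lloydStep`, the unit-square coordinates, the choice of A and C as diagonal corners) is an assumption made for illustration and says nothing about how MLlib breaks ties. With diagonal starting corners the other two corners are equidistant from both centers, so whether the centers move depends on the tie-breaking rule; the sketch therefore only checks that the two starting points lead to different final centers.

```scala
object SquareInitSketch {
  type Point = (Double, Double)

  private def sqDist(a: Point, b: Point): Double = {
    val dx = a._1 - b._1
    val dy = a._2 - b._2
    dx * dx + dy * dy
  }

  // One Lloyd iteration: assign each point to its nearest center, then move
  // each center to the mean of the points assigned to it.
  def lloydStep(points: Seq[Point], centers: Seq[Point]): Seq[Point] = {
    val assigned = points.groupBy(p => centers.indices.minBy(i => sqDist(p, centers(i))))
    centers.indices.map { i =>
      assigned.get(i) match {
        case Some(ps) => (ps.map(_._1).sum / ps.size, ps.map(_._2).sum / ps.size)
        case None => centers(i) // an empty cluster keeps its previous center
      }
    }
  }

  // Corners of the unit square, labeled in cyclic order (A and C are diagonal).
  val a: Point = (0.0, 0.0)
  val b: Point = (1.0, 0.0)
  val c: Point = (1.0, 1.0)
  val d: Point = (0.0, 1.0)
  val corners: Seq[Point] = Seq(a, b, c, d)

  // Results of one iteration from the two different starting points.
  val fromAC: Seq[Point] = lloydStep(corners, Seq(a, c))
  val fromBD: Seq[Point] = lloydStep(corners, Seq(b, d))
}
```

A Spark-based version of the test would compare the `clusterCenters` of the two trained models in the same way.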

val predicts = newModel.predict(rdd).collect()

assert(predicts(0) === predicts(1))
Member

Please fix indentation

assert(predicts(0) === predicts(2))
assert(predicts(3) === predicts(4))
assert(predicts(3) === predicts(5))
assert(predicts(0) != predicts(3))
}

}

object KMeansSuite extends SparkFunSuite {