-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-8018][MLlib]KMeans should accept initial cluster centers as param #6737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
6959861
e9c35d7
16f1b53
cd5dc5c
3f5fc8e
582e6d9
60c8ce2
714acb5
242ead1
e721dfe
07f8554
d12336e
06d13ef
c446c58
ef95ee2
94b56df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -156,6 +156,20 @@ class KMeans private ( | |
| this | ||
| } | ||
|
|
||
| // Initial cluster centers can be provided as a KMeansModel object rather than using the | ||
| // random or k-means|| initializationMode | ||
| private var initialModel: Option[KMeansModel] = None | ||
|
|
||
| /** Set the initial starting point, bypassing the random initialization or k-means|| | ||
| * The condition (model.k == this.k) must be met; failure will result in an | ||
| * IllegalArgumentException. | ||
| */ | ||
| def setInitialModel(model: KMeansModel): this.type = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm just wondering out loud, don't know if this makes sense -- should the user have to supply a whole model just to specify initial centroids? or can they just specify the centroids here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @srowen Thanks for the comments. We followed this approach based on @mengxr 's suggestion at https://issues.apache.org/jira/browse/SPARK-8018?focusedCommentId=14571757&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-14571757. Please have a look.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh I see, that's fine. |
||
| require(model.k==k, "mismatched cluster count") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing spaces around |
||
| initialModel = Some(model) | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Train a K-means model on the given set of points; `data` should be cached for high | ||
| * performance, because this is an iterative algorithm. | ||
|
|
@@ -193,12 +207,19 @@ class KMeans private ( | |
|
|
||
| val initStartTime = System.nanoTime() | ||
|
|
||
| val centers = if (initializationMode == KMeans.RANDOM) { | ||
| initRandom(data) | ||
| } else { | ||
| initKMeansParallel(data) | ||
| val centers = initialModel match { | ||
| case Some(kMeansCenters) => { | ||
| Array.tabulate(runs)(r => kMeansCenters.clusterCenters | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will run KMeans with the same center
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jkbradley Is it ok to explicitly set the value 1 in the place of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jkbradley Irrespective of the value of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not think setInitialModel should do anything other than set initialModel. When training/fitting happens, I think we should then examine initialModel. If initialModel is set, then we should act as though runs = 1 (but not actually change that value) and also print the warning message. The documentation should be added within this PR. |
||
| .map(s => new VectorWithNorm(s, Vectors.norm(s, 2.0)))) | ||
| } | ||
| case None => { | ||
| if (initializationMode == KMeans.RANDOM) { | ||
| initRandom(data) | ||
| } else { | ||
| initKMeansParallel(data) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 | ||
| logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + | ||
| " seconds.") | ||
|
|
@@ -478,6 +499,25 @@ object KMeans { | |
| train(data, k, maxIterations, runs, K_MEANS_PARALLEL) | ||
| } | ||
|
|
||
| /** | ||
| * Trains a k-means model using the given set of parameters and initial cluster centers | ||
| * | ||
| * @param data training points stored as `RDD[Vector]` | ||
| * @param k number of clusters | ||
| * @param maxIterations max number of iterations | ||
| * @param initialModel an existing set of cluster centers. | ||
| */ | ||
| def train( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure at this point what the thinking is on adding yet another overload to the utility method. At some point one is expected to use
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. This extra static method is not necessary since we decided we prefer the builder pattern, as @srowen said. |
||
| data: RDD[Vector], | ||
| k: Int, | ||
| maxIterations: Int, | ||
| initialModel: KMeansModel): KMeansModel = { | ||
| new KMeans().setK(k) | ||
| .setMaxIterations(maxIterations) | ||
| .setInitialModel(initialModel) | ||
| .run(data) | ||
| } | ||
|
|
||
| /** | ||
| * Returns the index of the closest center to the given point, as well as the squared distance. | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -278,6 +278,34 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| } | ||
| } | ||
| } | ||
|
|
||
| test("Initialize using given cluster centers") { | ||
| val points = Seq( | ||
| Vectors.dense(0.0, 0.0), | ||
| Vectors.dense(0.0, 0.1), | ||
| Vectors.dense(0.1, 0.0), | ||
| Vectors.dense(9.0, 0.0), | ||
| Vectors.dense(9.0, 0.2), | ||
| Vectors.dense(9.2, 0.0) | ||
| ) | ||
| val rdd = sc.parallelize(points, 3) | ||
| val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1) | ||
|
|
||
| val tempDir = Utils.createTempDir() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove temp dir after loading model back |
||
| val path = tempDir.toURI.toString | ||
| model.save(sc, path) | ||
| val loadedModel = KMeansModel.load(sc, path) | ||
|
|
||
| val newModel = KMeans.train(rdd, k = 2, maxIterations = 2, initialModel = loadedModel) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like a somewhat brittle test. It requires that the initial training find the correct centers. I imagine it generally would, but with so few data, I suspect it will fail every now and then. A better test might be:
(Basically, cluster centers should not move, so they should definitely be different in the end.) |
||
| val predicts = newModel.predict(rdd).collect() | ||
|
|
||
| assert(predicts(0) === predicts(1)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please fix indentation |
||
| assert(predicts(0) === predicts(2)) | ||
| assert(predicts(3) === predicts(4)) | ||
| assert(predicts(3) === predicts(5)) | ||
| assert(predicts(0) != predicts(3)) | ||
| } | ||
|
|
||
| } | ||
|
|
||
| object KMeansSuite extends SparkFunSuite { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Scala style: comment should begin on line after
/**(See other examples of multi-line comments in this file.)