Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions python/pyspark/ml/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,17 @@ def probability(self):
return self._call_java("probability")


class KMeansSummary(ClusteringSummary):
"""
.. note:: Experimental

Summary of KMeans.

.. versionadded:: 2.1.0
"""
pass


class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by KMeans.
Expand All @@ -316,6 +327,27 @@ def computeCost(self, dataset):
"""
return self._call_java("computeCost", dataset)

@property
@since("2.1.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model instance.
"""
return self._call_java("hasSummary")

@property
@since("2.1.0")
def summary(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you match the other implementations here as well?

"""
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
return KMeansSummary(self._call_java("summary"))
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)


@inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
Expand All @@ -341,6 +373,13 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
True
>>> rows[2].prediction == rows[3].prediction
True
>>> model.hasSummary
True
>>> summary = model.summary
>>> summary.k
2
>>> summary.clusterSizes
[2, 2]
>>> kmeans_path = temp_path + "/kmeans"
>>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path)
Expand All @@ -349,6 +388,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
>>> model_path = temp_path + "/kmeans_model"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add

    >>> model.hasSummary
    True
    >>> summary = model.summary
    >>> summary.k
    2
    >>> summary.clusterSizes
    [2, 2]

to match other doctests

>>> model.save(model_path)
>>> model2 = KMeansModel.load(model_path)
>>> model2.hasSummary
False
>>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1]
Expand Down
15 changes: 15 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,21 @@ def test_bisecting_kmeans_summary(self):
self.assertEqual(len(s.clusterSizes), 2)
self.assertEqual(s.k, 2)

def test_kmeans_summary(self):
data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
(Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
df = self.spark.createDataFrame(data, ["features"])
kmeans = KMeans(k=2, seed=1)
model = kmeans.fit(df)
self.assertTrue(model.hasSummary)
s = model.summary
self.assertTrue(isinstance(s.predictions, DataFrame))
self.assertEqual(s.featuresCol, "features")
self.assertEqual(s.predictionCol, "prediction")
self.assertTrue(isinstance(s.cluster, DataFrame))
self.assertEqual(len(s.clusterSizes), 2)
self.assertEqual(s.k, 2)


class OneVsRestTests(SparkSessionTestCase):

Expand Down