Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[SPARK-14894][PySpark] Add result summary api to Gaussian Mixture
  • Loading branch information
GayathriMurali committed Jun 18, 2016
commit c42f5dd69c5ee073e54c3f79e3e7eddbe8ec88c4
121 changes: 120 additions & 1 deletion python/pyspark/ml/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from pyspark import since, keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc

Expand Down Expand Up @@ -56,6 +56,125 @@ def gaussiansDF(self):
"""
return self._call_java("gaussiansDF")

@property
@since("2.0.0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we've shipped 2.0 we will need to update this to 2.1 (same with the versionAdded notes)

def summary(self):
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PyDoc seems oddly formatted in terms of line breaks.

Gets summary of model on
training set. An exception is thrown if
`trainingSummary is None`.
"""
java_gmt_summary = self._call_java("summary")
return GaussianMixtureTrainingSummary(java_gmt_summary)

@property
@since("2.0.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model
instance.
"""
return self._call_java("hasSummary")

@since("2.0.0")
def evaluate(self, dataset):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function doesn't appear to be present in the Scala GaussianMixtureModel - can you double check your functions against the Scaladoc to make sure your wrapping the correct methods?

"""
Evaluates the model on a test dataset.

:param dataset:
Test dataset to evaluate model on, where dataset is an
instance of :py:class:`pyspark.sql.DataFrame`
"""
if not isinstance(dataset, DataFrame):
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
java_gmt_summary = self._call_java("evaluate", dataset)
return GaussianMixtureSummary(java_gmt_summary)


class GaussianMixtureSummary(JavaWrapper):
"""
Abstraction for Gaussian Mixture Results for a given model.

.. versionadded:: 2.0.0
"""

@property
@since("2.0.0")
def predictions(self):
"""
Dataframe outputted by the model's `transform` method.
"""
return self._call_java("predictions")

@property
@since("2.0.0")
def probabilityCol(self):
"""
Field in "predictions" which gives the probability
of each class.
"""
return self._call_java("probabilityCol")

@property
@since("2.0.0")
def featuresCol(self):
"""
Field in "predictions" which gives the features of each instance.
"""
return self._call_java("featuresCol")

@property
@since("2.0.0")
def cluster(self):
"""
Cluster centers of the transformed data.
"""
return self._call_java("cluster")

@property
@since("2.0.0")
def probability(self):
"""
Probability of each cluster.
"""
return self._call_java("probability")

@property
@since("2.0.0")
def clusterSizes(self):
"""
Size of (number of data points in) each cluster.
"""
return self._call_java("clusterSizes")


@inherit_doc
class GaussianMixtureTrainingSummary(GaussianMixtureSummary):
"""
Abstraction for Gaussian Mixture Training results.
Currently, the training summary ignores the training weights except
for the objective trace.

.. versionadded:: 2.0.0
"""

@property
@since("2.0.0")
def objectiveHistory(self):
"""
Objective function (scaled loss + regularization) at each
iteration.
"""
return self._call_java("objectiveHistory")

@property
@since("2.0.0")
def totalIterations(self):
"""
Number of training iterations until termination.
"""
return self._call_java("totalIterations")


@inherit_doc
class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
Expand Down
52 changes: 52 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,58 @@ def test_logistic_regression_summary(self):
sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)

def test_gaussian_mixture_summary(self):
from pyspark.mllib.linalg import Vectors
sqlContext = SQLContext(self.sc)
df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
(0.0, 2.0, Vectors.sparse(1, [], []))],
["features"])
gm = GaussianMixture(k=3, tol=0.0001, maxIter=10, seed=10)
model = gm.fit(df)
self.assertTrue(model.hasSummary)
s = model.summary
# test that api is callable and returns expected types
self.assertGreater(s.totalIterations, 0)
self.assertTrue(isinstance(s.predictions, DataFrame))
self.assertEqual(s.predictionCol, "prediction")
self.assertEqual(s.featuresCol, "features")
objHist = s.objectiveHistory
cluster_sizes = s.clusterSizes
self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
self.assertTrue(isinstance(s.cluster, DataFrame))
self.assertTrue(isinstance(s.probability, DataFrame))
self.assertEqual(isinstance(cluster_sizes[0], long))
# test evaluation (with a training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.cluster, s.cluster)

def test_gaussian_mixture_summary(self):
from pyspark.mllib.linalg import Vectors
sqlContext = SQLContext(self.sc)
df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
(0.0, 2.0, Vectors.sparse(1, [], []))],
["features"])
gm = GaussianMixture(k=3, tol=0.0001, maxIter=10, seed=10)
model = gm.fit(df)
self.assertTrue(model.hasSummary)
s = model.summary
# test that api is callable and returns expected types
self.assertGreater(s.totalIterations, 0)
self.assertTrue(isinstance(s.predictions, DataFrame))
self.assertEqual(s.predictionCol, "prediction")
self.assertEqual(s.featuresCol, "features")
objHist = s.objectiveHistory
cluster_sizes = s.clusterSizes
self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
self.assertTrue(isinstance(s.cluster, DataFrame))
self.assertTrue(isinstance(s.probability, DataFrame))
self.assertEqual(isinstance(cluster_sizes[0], long))
# test evaluation (with a training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.cluster, s.cluster)


class OneVsRestTests(SparkSessionTestCase):

Expand Down