-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-14894][PySpark] Add result summary api to Gaussian Mixture #12675
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
c42f5dd
7db9c0d
b19756c
7d16a23
3cf080a
c2b1aef
3bc75a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
|
|
||
| from pyspark import since, keyword_only | ||
| from pyspark.ml.util import * | ||
| from pyspark.ml.wrapper import JavaEstimator, JavaModel | ||
| from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper | ||
| from pyspark.ml.param.shared import * | ||
| from pyspark.ml.common import inherit_doc | ||
|
|
||
|
|
@@ -56,6 +56,125 @@ def gaussiansDF(self): | |
| """ | ||
| return self._call_java("gaussiansDF") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def summary(self): | ||
| """ | ||
|
||
| Gets summary of model on | ||
| training set. An exception is thrown if | ||
| `trainingSummary is None`. | ||
| """ | ||
| java_gmt_summary = self._call_java("summary") | ||
| return GaussianMixtureTrainingSummary(java_gmt_summary) | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def hasSummary(self): | ||
| """ | ||
| Indicates whether a training summary exists for this model | ||
| instance. | ||
| """ | ||
| return self._call_java("hasSummary") | ||
|
|
||
| @since("2.0.0") | ||
| def evaluate(self, dataset): | ||
|
||
| """ | ||
| Evaluates the model on a test dataset. | ||
|
|
||
| :param dataset: | ||
| Test dataset to evaluate model on, where dataset is an | ||
| instance of :py:class:`pyspark.sql.DataFrame` | ||
| """ | ||
| if not isinstance(dataset, DataFrame): | ||
| raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) | ||
| java_gmt_summary = self._call_java("evaluate", dataset) | ||
| return GaussianMixtureSummary(java_gmt_summary) | ||
|
|
||
|
|
||
| class GaussianMixtureSummary(JavaWrapper): | ||
| """ | ||
| Abstraction for Gaussian Mixture Results for a given model. | ||
|
|
||
| .. versionadded:: 2.0.0 | ||
| """ | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def predictions(self): | ||
| """ | ||
| Dataframe outputted by the model's `transform` method. | ||
| """ | ||
| return self._call_java("predictions") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def probabilityCol(self): | ||
| """ | ||
| Field in "predictions" which gives the probability | ||
| of each class. | ||
| """ | ||
| return self._call_java("probabilityCol") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def featuresCol(self): | ||
| """ | ||
| Field in "predictions" which gives the features of each instance. | ||
| """ | ||
| return self._call_java("featuresCol") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def cluster(self): | ||
| """ | ||
| Cluster centers of the transformed data. | ||
| """ | ||
| return self._call_java("cluster") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def probability(self): | ||
| """ | ||
| Probability of each cluster. | ||
| """ | ||
| return self._call_java("probability") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def clusterSizes(self): | ||
| """ | ||
| Size of (number of data points in) each cluster. | ||
| """ | ||
| return self._call_java("clusterSizes") | ||
|
|
||
|
|
||
| @inherit_doc | ||
| class GaussianMixtureTrainingSummary(GaussianMixtureSummary): | ||
| """ | ||
| Abstraction for Gaussian Mixture Training results. | ||
| Currently, the training summary ignores the training weights except | ||
| for the objective trace. | ||
|
|
||
| .. versionadded:: 2.0.0 | ||
| """ | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def objectiveHistory(self): | ||
| """ | ||
| Objective function (scaled loss + regularization) at each | ||
| iteration. | ||
| """ | ||
| return self._call_java("objectiveHistory") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def totalIterations(self): | ||
| """ | ||
| Number of training iterations until termination. | ||
| """ | ||
| return self._call_java("totalIterations") | ||
|
|
||
|
|
||
| @inherit_doc | ||
| class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we've shipped 2.0 we will need to update this to 2.1 (same with the versionAdded notes)