-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-21741][ML][PySpark] Python API for DataFrame-based multivariate summarizer #20695
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
488d45a
7d3cb1b
001ff46
b3e9ddd
e64f795
21edbcd
f7cec51
20968c1
b91dbeb
9a4a0ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,9 @@ | |
|
|
||
| from pyspark import since, SparkContext | ||
| from pyspark.ml.common import _java2py, _py2java | ||
| from pyspark.ml.wrapper import _jvm | ||
| from pyspark.ml.wrapper import JavaWrapper, _jvm | ||
| from pyspark.sql.column import Column, _to_seq | ||
| from pyspark.sql.functions import lit | ||
|
|
||
|
|
||
| class ChiSquareTest(object): | ||
|
|
@@ -195,6 +197,185 @@ def test(dataset, sampleCol, distName, *params): | |
| _jvm().PythonUtils.toSeq(params))) | ||
|
|
||
|
|
||
class Summarizer(object):
    """
    .. note:: Experimental

    Tools for vectorized statistics on MLlib Vectors.

    The methods in this package provide various statistics for Vectors contained inside DataFrames.
    This class lets users pick the statistics they would like to extract for a given column.

    >>> from pyspark.ml.stat import Summarizer
    >>> from pyspark.sql import Row
    >>> from pyspark.ml.linalg import Vectors
    >>> summarizer = Summarizer.metrics("mean", "count")
    >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
    ...                      Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
    >>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
    +-----------------------------------+
    |aggregate_metrics(features, weight)|
    +-----------------------------------+
    |[[1.0,1.0,1.0], 1]                 |
    +-----------------------------------+
    <BLANKLINE>
    >>> df.select(summarizer.summary(df.features)).show(truncate=False)
    +--------------------------------+
    |aggregate_metrics(features, 1.0)|
    +--------------------------------+
    |[[1.0,1.5,2.0], 2]              |
    +--------------------------------+
    <BLANKLINE>
    >>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False)
    +--------------+
    |mean(features)|
    +--------------+
    |[1.0,1.0,1.0] |
    +--------------+
    <BLANKLINE>
    >>> df.select(Summarizer.mean(df.features)).show(truncate=False)
    +--------------+
    |mean(features)|
    +--------------+
    |[1.0,1.5,2.0] |
    +--------------+
    <BLANKLINE>

    .. versionadded:: 2.4.0

    """
    @staticmethod
    @since("2.4.0")
    def mean(col, weightCol=None):
        """
        return a column of mean summary
        """
        return Summarizer._get_single_metric(col, weightCol, "mean")

    @staticmethod
    @since("2.4.0")
    def variance(col, weightCol=None):
        """
        return a column of variance summary
        """
        return Summarizer._get_single_metric(col, weightCol, "variance")

    @staticmethod
    @since("2.4.0")
    def count(col, weightCol=None):
        """
        return a column of count summary
        """
        return Summarizer._get_single_metric(col, weightCol, "count")

    @staticmethod
    @since("2.4.0")
    def numNonZeros(col, weightCol=None):
        """
        return a column of numNonZero summary
        """
        return Summarizer._get_single_metric(col, weightCol, "numNonZeros")

    @staticmethod
    @since("2.4.0")
    def max(col, weightCol=None):
        """
        return a column of max summary
        """
        return Summarizer._get_single_metric(col, weightCol, "max")

    @staticmethod
    @since("2.4.0")
    def min(col, weightCol=None):
        """
        return a column of min summary
        """
        return Summarizer._get_single_metric(col, weightCol, "min")

    @staticmethod
    @since("2.4.0")
    def normL1(col, weightCol=None):
        """
        return a column of normL1 summary
        """
        return Summarizer._get_single_metric(col, weightCol, "normL1")

    @staticmethod
    @since("2.4.0")
    def normL2(col, weightCol=None):
        """
        return a column of normL2 summary
        """
        return Summarizer._get_single_metric(col, weightCol, "normL2")

    @staticmethod
    def _check_param(featureCol, weightCol):
        # Default the weight to a constant 1.0 when no weight column is given,
        # mirroring the unweighted case on the Scala side.
        if weightCol is None:
            weightCol = lit(1.0)
        if not isinstance(featureCol, Column) or not isinstance(weightCol, Column):
            raise TypeError("featureCol and weightCol should be a Column")
        return featureCol, weightCol

    @staticmethod
    def _get_single_metric(col, weightCol, metric):
        # Delegates to the single-metric helper of the same name on the JVM side
        # (org.apache.spark.ml.stat.Summarizer.<metric>), wrapping the resulting
        # Java column in a Python Column.
        col, weightCol = Summarizer._check_param(col, weightCol)
        return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric,
                                                col._jc, weightCol._jc))

    @staticmethod
    @since("2.4.0")
    def metrics(*metrics):
        """
        Given a list of metrics, provides a builder that computes metrics from a column.

        See the documentation of :py:class:`Summarizer` for an example.

        The following metrics are accepted (case sensitive):
         - mean: a vector that contains the coefficient-wise mean.
         - variance: a vector that contains the coefficient-wise variance.
         - count: the count of all vectors seen.
         - numNonZeros: a vector with the number of non-zeros for each coefficients
         - max: the maximum for each coefficient.
         - min: the minimum for each coefficient.
         - normL2: the Euclidean norm for each coefficient.
         - normL1: the L1 norm of each coefficient (sum of the absolute values).

        :param metrics:
         metrics that can be provided.
        :return:
         an object of :py:class:`pyspark.ml.stat.SummarizerBuilder`

        Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD
        interface.
        """
        sc = SparkContext._active_spark_context
        js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics",
                                       _to_seq(sc, metrics))
        return SummarizerBuilder(js)
|
|
||
|
|
||
class SummarizerBuilder(object):
    """
    .. note:: Experimental

    A builder object that provides summary statistics about a given column.

    Users should not directly create such builders, but instead use one of the methods in
    :py:class:`pyspark.ml.stat.Summarizer`

    .. versionadded:: 2.4.0

    """
    def __init__(self, js):
        # Wrapped JVM SummaryBuilder produced by Summarizer.metrics(...).
        self._js = js

    @since("2.4.0")
    def summary(self, featureCol, weightCol=None):
        """
        Returns an aggregate object that contains the summary of the column with the requested
        metrics.

        :param featureCol:
         a column that contains features Vector object.
        :param weightCol:
         a column that contains weight value. Default weight is 1.0.
        :return:
         an aggregate column that contains the statistics. The exact content of this
         structure is determined during the creation of the builder.
        """
        featureCol, weightCol = Summarizer._check_param(featureCol, weightCol)
        return Column(self._js.summary(featureCol._jc, weightCol._jc))
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| import doctest | ||
| import pyspark.ml.stat | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This name needs to match its Scala equivalent: "SummaryBuilder"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, shouldn't we use JavaWrapper for this? That will clean up when this object is destroyed.