183 changes: 182 additions & 1 deletion python/pyspark/ml/stat.py
@@ -19,7 +19,9 @@

from pyspark import since, SparkContext
from pyspark.ml.common import _java2py, _py2java
from pyspark.ml.wrapper import _jvm
from pyspark.ml.wrapper import JavaWrapper, _jvm
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.functions import lit


class ChiSquareTest(object):
@@ -195,6 +197,185 @@ def test(dataset, sampleCol, distName, *params):
_jvm().PythonUtils.toSeq(params)))


class Summarizer(object):
"""
.. note:: Experimental

Tools for vectorized statistics on MLlib Vectors.
The methods in this package provide various statistics for Vectors contained inside DataFrames.
This class lets users pick the statistics they would like to extract for a given column.

>>> from pyspark.ml.stat import Summarizer
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> summarizer = Summarizer.metrics("mean", "count")
>>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
... Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
>>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
+-----------------------------------+
|aggregate_metrics(features, weight)|
+-----------------------------------+
|[[1.0,1.0,1.0], 1] |
+-----------------------------------+
<BLANKLINE>
>>> df.select(summarizer.summary(df.features)).show(truncate=False)
+--------------------------------+
|aggregate_metrics(features, 1.0)|
+--------------------------------+
|[[1.0,1.5,2.0], 2] |
+--------------------------------+
<BLANKLINE>
>>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False)
+--------------+
|mean(features)|
+--------------+
|[1.0,1.0,1.0] |
+--------------+
<BLANKLINE>
>>> df.select(Summarizer.mean(df.features)).show(truncate=False)
+--------------+
|mean(features)|
+--------------+
|[1.0,1.5,2.0] |
+--------------+
<BLANKLINE>

.. versionadded:: 2.4.0

"""
@staticmethod
@since("2.4.0")
def mean(col, weightCol=None):
"""
Return a column of mean summary.
"""
return Summarizer._get_single_metric(col, weightCol, "mean")

@staticmethod
@since("2.4.0")
def variance(col, weightCol=None):
"""
Return a column of variance summary.
"""
return Summarizer._get_single_metric(col, weightCol, "variance")

@staticmethod
@since("2.4.0")
def count(col, weightCol=None):
"""
Return a column of count summary.
"""
return Summarizer._get_single_metric(col, weightCol, "count")

@staticmethod
@since("2.4.0")
def numNonZeros(col, weightCol=None):
"""
Return a column of numNonZeros summary.
"""
return Summarizer._get_single_metric(col, weightCol, "numNonZeros")

@staticmethod
@since("2.4.0")
def max(col, weightCol=None):
"""
Return a column of max summary.
"""
return Summarizer._get_single_metric(col, weightCol, "max")

@staticmethod
@since("2.4.0")
def min(col, weightCol=None):
"""
Return a column of min summary.
"""
return Summarizer._get_single_metric(col, weightCol, "min")

@staticmethod
@since("2.4.0")
def normL1(col, weightCol=None):
"""
Return a column of normL1 summary.
"""
return Summarizer._get_single_metric(col, weightCol, "normL1")

@staticmethod
@since("2.4.0")
def normL2(col, weightCol=None):
"""
Return a column of normL2 summary.
"""
return Summarizer._get_single_metric(col, weightCol, "normL2")

@staticmethod
def _check_param(featureCol, weightCol):
if weightCol is None:
weightCol = lit(1.0)
if not isinstance(featureCol, Column) or not isinstance(weightCol, Column):
raise TypeError("featureCol and weightCol should each be a Column")
return featureCol, weightCol

@staticmethod
def _get_single_metric(col, weightCol, metric):
col, weightCol = Summarizer._check_param(col, weightCol)
return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric,
col._jc, weightCol._jc))

@staticmethod
@since("2.4.0")
def metrics(*metrics):
"""
Given a list of metrics, provides a builder that in turn computes metrics from a column.

See the documentation of :py:class:`Summarizer` for an example.

The following metrics are accepted (case sensitive):
- mean: a vector that contains the coefficient-wise mean.
- variance: a vector that contains the coefficient-wise variance.
- count: the count of all vectors seen.
- numNonZeros: a vector with the number of non-zeros for each coefficient.
- max: the maximum for each coefficient.
- min: the minimum for each coefficient.
- normL2: the Euclidean norm for each coefficient.
- normL1: the L1 norm of each coefficient (sum of the absolute values).

:param metrics: metrics that can be provided.
:return: a builder object (:py:class:`SummarizerBuilder`).

Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD
interface (see the RDD-based sketch after this method).
"""
sc = SparkContext._active_spark_context
js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics",
_to_seq(sc, metrics))
return SummarizerBuilder(js)
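For context on the performance note above, a minimal sketch of the RDD-based interface it is compared against, using the existing pyspark.mllib colStats API (sc and the sample vectors mirror the doctest at the top of this class):

from pyspark.mllib.linalg import Vectors
from pyspark.mllib.stat import Statistics

# One pass over an RDD of Vectors; colStats has no notion of a weight column.
rdd = sc.parallelize([Vectors.dense(1.0, 1.0, 1.0), Vectors.dense(1.0, 2.0, 3.0)])
summary = Statistics.colStats(rdd)
print(summary.mean())   # approximately [1.0, 1.5, 2.0]
print(summary.count())  # 2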


class SummarizerBuilder(object):
Member: This name needs to match its Scala equivalent: "SummaryBuilder"

Member: Also, shouldn't we use JavaWrapper for this? That will clean up when this object is destroyed.

"""
.. note:: Experimental

A builder object that provides summary statistics about a given column.

Users should not directly create such builders, but instead use one of the methods in
:py:class:`pyspark.ml.stat.Summarizer`

.. versionadded:: 2.4.0

"""
def __init__(self, js):
self._js = js
Member: This should call the super's init method, and it should store js in _java_obj (which is set in the JavaWrapper init).
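A minimal sketch of the change this comment asks for, assuming JavaWrapper.__init__ in pyspark.ml.wrapper accepts the Java object and stores it in _java_obj:

from pyspark.ml.wrapper import JavaWrapper

class SummarizerBuilder(JavaWrapper):
    def __init__(self, js):
        # Delegate to JavaWrapper so the Java builder lands in _java_obj and is
        # cleaned up by JavaWrapper when this Python object is destroyed.
        super(SummarizerBuilder, self).__init__(js)

Methods on the builder would then refer to self._java_obj instead of self._js.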


@since("2.4.0")
def summary(self, featureCol, weightCol=None):
Contributor: We might want to move the "summary" method into another class, and have Summarizer only contain static methods. That will help with autocomplete so that it's clear that you're not meant to do Summarizer.metrics("min").mean(features).

Contributor Author: Sounds reasonable.
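A hypothetical shape of that split (class and method names are illustrative, not from the patch): the static helpers stay on Summarizer, while summary() lives only on the builder returned by metrics(), so a chain like Summarizer.metrics("min").mean(features) no longer autocompletes.

class Summarizer(object):
    @staticmethod
    def metrics(*metrics):
        """Build and return a builder for the requested metrics."""
        ...

class SummaryBuilder(object):
    def summary(self, featuresCol, weightCol=None):
        """Return the aggregate column for the metrics chosen at build time."""
        ...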

Member: ditto: naming should match Scala: "featuresCol"

"""
Returns an aggregate object that contains the summary of the column with the requested
metrics.
Member: Let's copy the docs for arguments & return value from Scala

"""
featureCol, weightCol = Summarizer._check_param(featureCol, weightCol)
return Column(self._js.summary(featureCol._jc, weightCol._jc))
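For completeness, a small usage sketch of this method, reusing the DataFrame from the doctest at the top of the class:

# Build the summarizer once, then reuse it; only the requested metrics
# ("mean" and "count") appear in the resulting struct column.
summarizer = Summarizer.metrics("mean", "count")
df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
df.groupBy().agg(summarizer.summary(df.features)).show(truncate=False)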


if __name__ == "__main__":
import doctest
import pyspark.ml.stat