-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19634][ML] Multivariate summarizer - dataframes API #18798
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
281d546
2860390
4f32e27
6053d0e
b02db42
7540c4c
b081fc3
c82958f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,6 @@ import org.apache.spark.sql.Column | |
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData} | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate} | ||
| import org.apache.spark.sql.catalyst.util.ArrayData | ||
| import org.apache.spark.sql.functions.lit | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
|
|
@@ -62,7 +61,7 @@ abstract class SummaryBuilder { | |
| * {{{ | ||
| * val dataframe = ... // Some dataframe containing a feature column | ||
| * val allStats = dataframe.select(Summarizer.metrics("min", "max").summary($"features")) | ||
| * val Row(min_, max_) = allStats.first() | ||
| * val Row(Row(min_, max_)) = allStats.first() | ||
| * }}} | ||
| * | ||
| * If one wants to get a single metric, shortcuts are also available: | ||
|
|
@@ -107,20 +106,28 @@ object Summarizer extends Logging { | |
| new SummaryBuilderImpl(typedMetrics, computeMetrics) | ||
| } | ||
|
|
||
| @Since("2.3.0") | ||
| def mean(col: Column): Column = getSingleMetric(col, "mean") | ||
|
|
||
| @Since("2.3.0") | ||
| def variance(col: Column): Column = getSingleMetric(col, "variance") | ||
|
|
||
| @Since("2.3.0") | ||
| def count(col: Column): Column = getSingleMetric(col, "count") | ||
|
|
||
| @Since("2.3.0") | ||
| def numNonZeros(col: Column): Column = getSingleMetric(col, "numNonZeros") | ||
|
|
||
| @Since("2.3.0") | ||
| def max(col: Column): Column = getSingleMetric(col, "max") | ||
|
|
||
| @Since("2.3.0") | ||
| def min(col: Column): Column = getSingleMetric(col, "min") | ||
|
|
||
| @Since("2.3.0") | ||
| def normL1(col: Column): Column = getSingleMetric(col, "normL1") | ||
|
|
||
| @Since("2.3.0") | ||
| def normL2(col: Column): Column = getSingleMetric(col, "normL2") | ||
|
|
||
| private def getSingleMetric(col: Column, metric: String): Column = { | ||
|
|
@@ -130,8 +137,8 @@ object Summarizer extends Logging { | |
| } | ||
|
|
||
| private[ml] class SummaryBuilderImpl( | ||
| requestedMetrics: Seq[SummaryBuilderImpl.Metrics], | ||
| requestedCompMetrics: Seq[SummaryBuilderImpl.ComputeMetrics] | ||
| requestedMetrics: Seq[SummaryBuilderImpl.Metric], | ||
| requestedCompMetrics: Seq[SummaryBuilderImpl.ComputeMetric] | ||
| ) extends SummaryBuilder { | ||
|
|
||
| override def summary(featuresCol: Column, weightCol: Column): Column = { | ||
|
|
@@ -154,9 +161,9 @@ object SummaryBuilderImpl extends Logging { | |
| def implementedMetrics: Seq[String] = allMetrics.map(_._1).sorted | ||
|
|
||
| @throws[IllegalArgumentException]("When the list is empty or not a subset of known metrics") | ||
| def getRelevantMetrics(requested: Seq[String]): (Seq[Metrics], Seq[ComputeMetrics]) = { | ||
| def getRelevantMetrics(requested: Seq[String]): (Seq[Metric], Seq[ComputeMetric]) = { | ||
| val all = requested.map { req => | ||
| val (_, metric, _, deps) = allMetrics.find(tup => tup._1 == req).getOrElse { | ||
| val (_, metric, _, deps) = allMetrics.find(_._1 == req).getOrElse { | ||
| throw new IllegalArgumentException(s"Metric $req cannot be found." + | ||
| s" Valid metrics are $implementedMetrics") | ||
| } | ||
|
|
@@ -169,10 +176,12 @@ object SummaryBuilderImpl extends Logging { | |
| metrics -> computeMetrics | ||
| } | ||
|
|
||
| def structureForMetrics(metrics: Seq[Metrics]): StructType = { | ||
| val dct = allMetrics.map { case (n, m, dt, _) => (m, (n, dt)) }.toMap | ||
| val fields = metrics.map(dct.apply).map { case (n, dt) => | ||
| StructField(n, dt, nullable = false) | ||
| def structureForMetrics(metrics: Seq[Metric]): StructType = { | ||
| val dict = allMetrics.map { case (name, metric, dataType, _) => | ||
| (metric, (name, dataType)) | ||
| }.toMap | ||
| val fields = metrics.map(dict.apply).map { case (name, dataType) => | ||
| StructField(name, dataType, nullable = false) | ||
| } | ||
| StructType(fields) | ||
| } | ||
|
|
@@ -186,7 +195,7 @@ object SummaryBuilderImpl extends Logging { | |
| * This list associates the user name, the internal (typed) name, and the list of computation | ||
| * metrics that need to de computed internally to get the final result. | ||
| */ | ||
| private val allMetrics: Seq[(String, Metrics, DataType, Seq[ComputeMetrics])] = Seq( | ||
| private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq( | ||
| ("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)), | ||
| ("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)), | ||
| ("count", Count, LongType, Seq()), | ||
|
|
@@ -200,34 +209,34 @@ object SummaryBuilderImpl extends Logging { | |
| /** | ||
| * The metrics that are currently implemented. | ||
| */ | ||
| sealed trait Metrics extends Serializable | ||
| case object Mean extends Metrics | ||
| case object Variance extends Metrics | ||
| case object Count extends Metrics | ||
| case object NumNonZeros extends Metrics | ||
| case object Max extends Metrics | ||
| case object Min extends Metrics | ||
| case object NormL2 extends Metrics | ||
| case object NormL1 extends Metrics | ||
| sealed trait Metric extends Serializable | ||
| private[stat] case object Mean extends Metric | ||
| private[stat] case object Variance extends Metric | ||
| private[stat] case object Count extends Metric | ||
| private[stat] case object NumNonZeros extends Metric | ||
| private[stat] case object Max extends Metric | ||
| private[stat] case object Min extends Metric | ||
| private[stat] case object NormL2 extends Metric | ||
| private[stat] case object NormL1 extends Metric | ||
|
|
||
| /** | ||
| * The running metrics that are going to be computed. | ||
| * | ||
| * There is a bipartite graph between the metrics and the computed metrics. | ||
| */ | ||
| sealed trait ComputeMetrics extends Serializable | ||
| case object ComputeMean extends ComputeMetrics | ||
| case object ComputeM2n extends ComputeMetrics | ||
| case object ComputeM2 extends ComputeMetrics | ||
| case object ComputeL1 extends ComputeMetrics | ||
| case object ComputeWeightSum extends ComputeMetrics | ||
| case object ComputeNNZ extends ComputeMetrics | ||
| case object ComputeMax extends ComputeMetrics | ||
| case object ComputeMin extends ComputeMetrics | ||
| sealed trait ComputeMetric extends Serializable | ||
| private[stat] case object ComputeMean extends ComputeMetric | ||
| private[stat] case object ComputeM2n extends ComputeMetric | ||
| private[stat] case object ComputeM2 extends ComputeMetric | ||
| private[stat] case object ComputeL1 extends ComputeMetric | ||
| private[stat] case object ComputeWeightSum extends ComputeMetric | ||
| private[stat] case object ComputeNNZ extends ComputeMetric | ||
| private[stat] case object ComputeMax extends ComputeMetric | ||
| private[stat] case object ComputeMin extends ComputeMetric | ||
|
|
||
| private[stat] class SummarizerBuffer( | ||
| requestedMetrics: Seq[Metrics], | ||
| requestedCompMetrics: Seq[ComputeMetrics] | ||
| requestedMetrics: Seq[Metric], | ||
| requestedCompMetrics: Seq[ComputeMetric] | ||
| ) extends Serializable { | ||
|
|
||
| private var n = 0 | ||
|
|
@@ -255,7 +264,7 @@ object SummaryBuilderImpl extends Logging { | |
| * Add a new sample to this summarizer, and update the statistical summary. | ||
| */ | ||
| def add(instance: Vector, weight: Double): this.type = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From the usage of this
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @viirya I have tried your suggestion in the previous version code, but it do not bring performance advantage.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I saw it. Thanks.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If directly work on serialized data (UnsafeArrayData), it only avoid the array copy(which save little time), but brings extra cost when calling
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May I ask how the performance test runs? Especially for the RDD part.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @viirya
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. I didn't review the test thoroughly. |
||
| require(weight >= 0.0, s"sample weight, ${weight} has to be >= 0.0") | ||
| require(weight >= 0.0, s"sample weight, $weight has to be >= 0.0") | ||
| if (weight == 0.0) return this | ||
|
|
||
| if (n == 0) { | ||
|
|
@@ -510,16 +519,16 @@ object SummaryBuilderImpl extends Logging { | |
| } | ||
|
|
||
| private case class MetricsAggregate( | ||
| requestedMetrics: Seq[Metrics], | ||
| requestedComputeMetrics: Seq[ComputeMetrics], | ||
| requestedMetrics: Seq[Metric], | ||
| requestedComputeMetrics: Seq[ComputeMetric], | ||
| featuresExpr: Expression, | ||
| weightExpr: Expression, | ||
| mutableAggBufferOffset: Int, | ||
| inputAggBufferOffset: Int) | ||
| extends TypedImperativeAggregate[SummarizerBuffer] { | ||
|
|
||
| override def eval(state: SummarizerBuffer): InternalRow = { | ||
| val metrics = requestedMetrics.map({ | ||
| val metrics = requestedMetrics.map { | ||
| case Mean => UnsafeArrayData.fromPrimitiveArray(state.mean.toArray) | ||
| case Variance => UnsafeArrayData.fromPrimitiveArray(state.variance.toArray) | ||
| case Count => state.count | ||
|
|
@@ -529,30 +538,24 @@ object SummaryBuilderImpl extends Logging { | |
| case Min => UnsafeArrayData.fromPrimitiveArray(state.min.toArray) | ||
| case NormL2 => UnsafeArrayData.fromPrimitiveArray(state.normL2.toArray) | ||
| case NormL1 => UnsafeArrayData.fromPrimitiveArray(state.normL1.toArray) | ||
| }) | ||
| } | ||
| InternalRow.apply(metrics: _*) | ||
| } | ||
|
|
||
| override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil | ||
|
|
||
| override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = { | ||
| // val features = udt.deserialize(featuresExpr.eval(row)) | ||
| val featuresDatum = featuresExpr.eval(row).asInstanceOf[InternalRow] | ||
|
|
||
| val features = udt.deserialize(featuresDatum) | ||
|
|
||
| val features = udt.deserialize(featuresExpr.eval(row)) | ||
| val weight = weightExpr.eval(row).asInstanceOf[Double] | ||
|
|
||
| state.add(features, weight) | ||
| state | ||
| } | ||
|
|
||
| override def merge(state: SummarizerBuffer, | ||
| other: SummarizerBuffer): SummarizerBuffer = { | ||
| other: SummarizerBuffer): SummarizerBuffer = { | ||
| state.merge(other) | ||
| } | ||
|
|
||
|
|
||
| override def nullable: Boolean = false | ||
|
|
||
| override def createAggregationBuffer(): SummarizerBuffer | ||
|
|
@@ -589,5 +592,4 @@ object SummaryBuilderImpl extends Logging { | |
|
|
||
| private[this] val udt = new VectorUDT | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there some better way to get the object of |
||
|
|
||
| } | ||
|
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add
Since.