small updates
WeichenXu123 committed Aug 15, 2017
commit 7540c4c7ca53fd641ff8c15ee347352faa57fad9
94 changes: 48 additions & 46 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
- import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._

@@ -62,7 +61,7 @@ abstract class SummaryBuilder {
* {{{
* val dataframe = ... // Some dataframe containing a feature column
* val allStats = dataframe.select(Summarizer.metrics("min", "max").summary($"features"))
- * val Row(min_, max_) = allStats.first()
+ * val Row(Row(min_, max_)) = allStats.first()
* }}}
*
* If one wants to get a single metric, shortcuts are also available:
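A minimal end-to-end sketch of the documented usage (assuming a SparkSession in scope as spark with import spark.implicits._, and a made-up two-row DataFrame):

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.Row

val df = Seq(
  Tuple1(Vectors.dense(1.0, 2.0)),
  Tuple1(Vectors.dense(3.0, 4.0))
).toDF("features")

// metrics(...).summary(...) produces a single struct column, hence the
// nested Row(Row(...)) pattern in the updated scaladoc above.
val Row(Row(minVal, maxVal)) = df
  .select(Summarizer.metrics("min", "max").summary($"features"))
  .first()

// Single-metric shortcut: returns one flat column instead of a struct.
val meanRow = df.select(Summarizer.mean($"features")).first()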
@@ -107,20 +106,28 @@ object Summarizer extends Logging {
new SummaryBuilderImpl(typedMetrics, computeMetrics)
}

+ @Since("2.3.0")
def mean(col: Column): Column = getSingleMetric(col, "mean")
Contributor
Add Since.


+ @Since("2.3.0")
def variance(col: Column): Column = getSingleMetric(col, "variance")

+ @Since("2.3.0")
def count(col: Column): Column = getSingleMetric(col, "count")

+ @Since("2.3.0")
def numNonZeros(col: Column): Column = getSingleMetric(col, "numNonZeros")

+ @Since("2.3.0")
def max(col: Column): Column = getSingleMetric(col, "max")

+ @Since("2.3.0")
def min(col: Column): Column = getSingleMetric(col, "min")

+ @Since("2.3.0")
def normL1(col: Column): Column = getSingleMetric(col, "normL1")

+ @Since("2.3.0")
def normL2(col: Column): Column = getSingleMetric(col, "normL2")

private def getSingleMetric(col: Column, metric: String): Column = {
@@ -130,8 +137,8 @@
}

private[ml] class SummaryBuilderImpl(
- requestedMetrics: Seq[SummaryBuilderImpl.Metrics],
- requestedCompMetrics: Seq[SummaryBuilderImpl.ComputeMetrics]
+ requestedMetrics: Seq[SummaryBuilderImpl.Metric],
+ requestedCompMetrics: Seq[SummaryBuilderImpl.ComputeMetric]
) extends SummaryBuilder {

override def summary(featuresCol: Column, weightCol: Column): Column = {
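A hedged sketch of calling this weighted overload (assuming an input DataFrame df with "features" and "weight" columns, and import spark.implicits._ for the $ syntax):

val weightedStats = df.select(
  Summarizer.metrics("mean", "variance").summary($"features", $"weight"))
val firstRow = weightedStats.first()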
@@ -154,9 +161,9 @@ object SummaryBuilderImpl extends Logging {
def implementedMetrics: Seq[String] = allMetrics.map(_._1).sorted

@throws[IllegalArgumentException]("When the list is empty or not a subset of known metrics")
- def getRelevantMetrics(requested: Seq[String]): (Seq[Metrics], Seq[ComputeMetrics]) = {
+ def getRelevantMetrics(requested: Seq[String]): (Seq[Metric], Seq[ComputeMetric]) = {
val all = requested.map { req =>
- val (_, metric, _, deps) = allMetrics.find(tup => tup._1 == req).getOrElse {
+ val (_, metric, _, deps) = allMetrics.find(_._1 == req).getOrElse {
throw new IllegalArgumentException(s"Metric $req cannot be found." +
s" Valid metrics are $implementedMetrics")
}
@@ -169,10 +176,12 @@
metrics -> computeMetrics
}

- def structureForMetrics(metrics: Seq[Metrics]): StructType = {
-   val dct = allMetrics.map { case (n, m, dt, _) => (m, (n, dt)) }.toMap
-   val fields = metrics.map(dct.apply).map { case (n, dt) =>
-     StructField(n, dt, nullable = false)
+ def structureForMetrics(metrics: Seq[Metric]): StructType = {
+   val dict = allMetrics.map { case (name, metric, dataType, _) =>
+     (metric, (name, dataType))
+   }.toMap
+   val fields = metrics.map(dict.apply).map { case (name, dataType) =>
+     StructField(name, dataType, nullable = false)
}
StructType(fields)
}
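For illustration, a sketch of the struct this would produce for the "min" and "max" metrics, assuming arrayDType (defined elsewhere in this file) is an array of doubles:

// Assumed result of structureForMetrics(Seq(Min, Max)):
StructType(Seq(
  StructField("min", ArrayType(DoubleType, containsNull = false), nullable = false),
  StructField("max", ArrayType(DoubleType, containsNull = false), nullable = false)))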
@@ -186,7 +195,7 @@
* This list associates the user name, the internal (typed) name, and the list of computation
* metrics that need to be computed internally to get the final result.
*/
- private val allMetrics: Seq[(String, Metrics, DataType, Seq[ComputeMetrics])] = Seq(
+ private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq(
private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq(
("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)),
("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
("count", Count, LongType, Seq()),
@@ -200,34 +209,34 @@
/**
* The metrics that are currently implemented.
*/
- sealed trait Metrics extends Serializable
- case object Mean extends Metrics
- case object Variance extends Metrics
- case object Count extends Metrics
- case object NumNonZeros extends Metrics
- case object Max extends Metrics
- case object Min extends Metrics
- case object NormL2 extends Metrics
- case object NormL1 extends Metrics
+ sealed trait Metric extends Serializable
+ private[stat] case object Mean extends Metric
+ private[stat] case object Variance extends Metric
+ private[stat] case object Count extends Metric
+ private[stat] case object NumNonZeros extends Metric
+ private[stat] case object Max extends Metric
+ private[stat] case object Min extends Metric
+ private[stat] case object NormL2 extends Metric
+ private[stat] case object NormL1 extends Metric

/**
* The running metrics that are going to be computed.
*
* There is a bipartite graph between the metrics and the computed metrics.
*/
- sealed trait ComputeMetrics extends Serializable
- case object ComputeMean extends ComputeMetrics
- case object ComputeM2n extends ComputeMetrics
- case object ComputeM2 extends ComputeMetrics
- case object ComputeL1 extends ComputeMetrics
- case object ComputeWeightSum extends ComputeMetrics
- case object ComputeNNZ extends ComputeMetrics
- case object ComputeMax extends ComputeMetrics
- case object ComputeMin extends ComputeMetrics
+ sealed trait ComputeMetric extends Serializable
+ private[stat] case object ComputeMean extends ComputeMetric
+ private[stat] case object ComputeM2n extends ComputeMetric
+ private[stat] case object ComputeM2 extends ComputeMetric
+ private[stat] case object ComputeL1 extends ComputeMetric
+ private[stat] case object ComputeWeightSum extends ComputeMetric
+ private[stat] case object ComputeNNZ extends ComputeMetric
+ private[stat] case object ComputeMax extends ComputeMetric
+ private[stat] case object ComputeMin extends ComputeMetric

private[stat] class SummarizerBuffer(
- requestedMetrics: Seq[Metrics],
- requestedCompMetrics: Seq[ComputeMetrics]
+ requestedMetrics: Seq[Metric],
+ requestedCompMetrics: Seq[ComputeMetric]
) extends Serializable {

private var n = 0
@@ -255,7 +264,7 @@ object SummaryBuilderImpl extends Logging {
* Add a new sample to this summarizer, and update the statistical summary.
*/
def add(instance: Vector, weight: Double): this.type = {
Member
From the usage of this Vector, I think we can work directly on the serialized Vector data (size, indices, values). It should save much of the time spent deserializing the Vector.

Contributor Author
@viirya I tried your suggestion in the previous version of the code, but it did not bring a performance advantage. You can check my previous version (in the commit "optimize summarizer buffer") and run the tests.

Member
Oh, I saw it. Thanks.

Contributor Author
WeichenXu123 Aug 15, 2017
Working directly on the serialized data (UnsafeArrayData) only avoids the array copy (which saves little time), while adding extra cost on each UnsafeArrayData.getDouble call and increasing code complexity.

Member
May I ask how the performance test runs? Especially for the RDD part.

Contributor Author
@viirya Change ignore("performance test") to test("performance test"), then run the SummarizerSuite test.

Member
Thanks. I didn't review the test thoroughly.

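A sketch of the switch described in that thread, assuming the ScalaTest FunSuite style used by Spark's test suites:

// Before (committed state): the benchmark is disabled.
ignore("performance test") { /* benchmark body, skipped */ }

// After the local edit: the benchmark runs with the suite.
test("performance test") { /* benchmark body, executed */ }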
require(weight >= 0.0, s"sample weight, ${weight} has to be >= 0.0")
require(weight >= 0.0, s"sample weight, $weight has to be >= 0.0")
if (weight == 0.0) return this

if (n == 0) {
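As a standalone simplification (an assumed sketch, not the actual buffer code), the kind of weighted incremental update add performs per element looks like:

class RunningWeightedMean(dim: Int) extends Serializable {
  private val mean = new Array[Double](dim)
  private var weightSum = 0.0

  def add(values: Array[Double], weight: Double): this.type = {
    require(weight >= 0.0, s"sample weight, $weight has to be >= 0.0")
    if (weight == 0.0) return this
    weightSum += weight
    var i = 0
    while (i < dim) {
      // Incremental weighted mean: m += (x - m) * w / W, with W including w.
      mean(i) += (values(i) - mean(i)) * weight / weightSum
      i += 1
    }
    this
  }
}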
@@ -510,16 +519,16 @@
}

private case class MetricsAggregate(
- requestedMetrics: Seq[Metrics],
- requestedComputeMetrics: Seq[ComputeMetrics],
+ requestedMetrics: Seq[Metric],
+ requestedComputeMetrics: Seq[ComputeMetric],
featuresExpr: Expression,
weightExpr: Expression,
mutableAggBufferOffset: Int,
inputAggBufferOffset: Int)
extends TypedImperativeAggregate[SummarizerBuffer] {

override def eval(state: SummarizerBuffer): InternalRow = {
- val metrics = requestedMetrics.map({
+ val metrics = requestedMetrics.map {
case Mean => UnsafeArrayData.fromPrimitiveArray(state.mean.toArray)
case Variance => UnsafeArrayData.fromPrimitiveArray(state.variance.toArray)
case Count => state.count
@@ -529,30 +538,24 @@ object SummaryBuilderImpl extends Logging {
case Min => UnsafeArrayData.fromPrimitiveArray(state.min.toArray)
case NormL2 => UnsafeArrayData.fromPrimitiveArray(state.normL2.toArray)
case NormL1 => UnsafeArrayData.fromPrimitiveArray(state.normL1.toArray)
- })
+ }
InternalRow.apply(metrics: _*)
}

override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil

override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
- // val features = udt.deserialize(featuresExpr.eval(row))
- val featuresDatum = featuresExpr.eval(row).asInstanceOf[InternalRow]
-
- val features = udt.deserialize(featuresDatum)
-
+ val features = udt.deserialize(featuresExpr.eval(row))
val weight = weightExpr.eval(row).asInstanceOf[Double]

state.add(features, weight)
state
}

override def merge(state: SummarizerBuffer,
- other: SummarizerBuffer): SummarizerBuffer = {
+     other: SummarizerBuffer): SummarizerBuffer = {
state.merge(other)
}
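For a single mean statistic, the pairwise combine that Spark applies to partial aggregation buffers reduces to a weight-proportional blend; a sketch under that assumption:

// Merge two weighted running means (m1, w1) and (m2, w2).
def mergeMeans(m1: Double, w1: Double, m2: Double, w2: Double): (Double, Double) = {
  val w = w1 + w2
  if (w == 0.0) (0.0, 0.0) else ((m1 * w1 + m2 * w2) / w, w)
}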


override def nullable: Boolean = false

override def createAggregationBuffer(): SummarizerBuffer
@@ -589,5 +592,4 @@ object SummaryBuilderImpl extends Logging {

private[this] val udt = new VectorUDT
Contributor Author
Is there a better way to get a VectorUDT instance? cc @cloud-fan


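One assumed alternative: org.apache.spark.ml.linalg.SQLDataTypes exposes the vector SQL type publicly, though it returns a plain DataType rather than a VectorUDT, so a cast would still be needed here:

import org.apache.spark.ml.linalg.SQLDataTypes

// Public accessor for the vector SQL type.
val vectorType = SQLDataTypes.VectorType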
}

}