Commits (24):
d1e22d5: tuning summary (YY-OnCall, Dec 3, 2016)
a7cfa63: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Dec 5, 2016)
ad73c12: add ut (YY-OnCall, Dec 5, 2016)
425a419: add comments (YY-OnCall, Dec 5, 2016)
bd18c00: get default spark session (YY-OnCall, Dec 13, 2016)
2a0af1d: resolve merge conflict (YY-OnCall, Feb 22, 2017)
1a594d0: merge conflict (YY-OnCall, Jul 5, 2017)
0a698fe: support cross validation (YY-OnCall, Jul 5, 2017)
bd459b1: update version (YY-OnCall, Jul 5, 2017)
4e3e19c: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Jul 24, 2017)
c0bc81a: improve unit test (YY-OnCall, Jul 24, 2017)
bbf3f9f: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Aug 8, 2017)
b6a7c53: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Aug 9, 2017)
72aea62: remove TuningSummary (YY-OnCall, Aug 9, 2017)
91da358: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Sep 10, 2017)
297091f: update for pipeline (YY-OnCall, Sep 11, 2017)
670467a: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Nov 19, 2017)
36b1dd5: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Dec 29, 2017)
5844e0c: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Jan 1, 2018)
8c829b5: merge conflict (YY-OnCall, Jul 23, 2018)
4aaf8e5: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Jul 23, 2018)
ceaad1c: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Jul 27, 2018)
41c4c12: Merge remote-tracking branch 'upstream/master' into tuningsummary (YY-OnCall, Jul 29, 2018)
4aef3aa: remove sort add comments (YY-OnCall, Jul 29, 2018)
@@ -112,6 +112,7 @@ object ModelSelectionViaCrossValidationExample {
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
println(s"($id, $text) --> prob=$prob, prediction=$prediction")
}
cvModel.tuningSummary.show()
// $example off$

spark.stop()
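For readers of this example, the summary can also be inspected programmatically. A minimal usage sketch (hypothetical, not part of this diff); the summary stores every column as a string, so the metric is cast before sorting:

import org.apache.spark.sql.functions.col

// Hypothetical usage: order candidates by the metric column, descending
// (assuming a larger-is-better metric such as areaUnderROC).
val summary = cvModel.tuningSummary
val metricCol = summary.columns.last
summary.orderBy(col(metricCol).cast("double").desc).show(truncate = false)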
@@ -74,6 +74,7 @@ object ModelSelectionViaTrainValidationSplitExample {
model.transform(test)
.select("features", "label", "prediction")
.show()
model.tuningSummary.show()
// $example off$

spark.stop()
@@ -229,6 +229,13 @@ class CrossValidatorModel private[ml] (
bestModel.transformSchema(schema)
}

/**
* Summary of the grid search tuning as a DataFrame. Each row contains one candidate
* paramMap and the corresponding metric of the trained model.
*/
@Since("2.3.0")
lazy val tuningSummary: DataFrame = this.getTuningSummaryDF(avgMetrics)

@Since("1.4.0")
override def copy(extra: ParamMap): CrossValidatorModel = {
val copied = new CrossValidatorModel(
@@ -220,6 +220,13 @@ class TrainValidationSplitModel private[ml] (
bestModel.transformSchema(schema)
}

/**
* Summary of the grid search tuning as a DataFrame. Each row contains one candidate
* paramMap and the corresponding metric of the trained model.
*/
@Since("2.3.0")
lazy val tuningSummary: DataFrame = this.getTuningSummaryDF(validationMetrics)

@Since("1.5.0")
override def copy(extra: ParamMap): TrainValidationSplitModel = {
val copied = new TrainValidationSplitModel (
@@ -23,12 +23,13 @@ import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator}
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

/**
* Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
@@ -85,6 +86,32 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
}


/**
* Summary of the grid search tuning as a DataFrame. Each row contains one candidate
* paramMap and the corresponding metric of the trained model.
*/
protected def getTuningSummaryDF(metrics: Array[Double]): DataFrame = {
val params = $(estimatorParamMaps)
require(params.nonEmpty, "estimator param maps should not be empty")
require(params.length == metrics.length, "the number of estimator param maps should match the number of metrics")
val metricName = $(evaluator) match {
case b: BinaryClassificationEvaluator => b.getMetricName
case m: MulticlassClassificationEvaluator => m.getMetricName
case r: RegressionEvaluator => r.getMetricName
case _ => "metrics"
}
val spark = SparkSession.builder().getOrCreate()
val sc = spark.sparkContext
val fields = params(0).toSeq.sortBy(_.param.name).map(_.param.name) ++ Seq(metricName)
val schema = new StructType(fields.map(name => StructField(name, StringType)).toArray)
val rows = sc.parallelize(params.zip(metrics)).map { case (param, metric) =>
val values = param.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString)

Reviewer comment (Contributor):
There seems to be a problem here:
Suppose params(0) (which is a ParamMap) contains ParamA and ParamB,
and params(1) (which is a ParamMap) contains ParamA and ParamC.
The code here will run into problems, because you compose the row values sorted by param name but never check that every ParamMap exactly matches the first one.
I think a better way is to go through the whole ParamMap list, collect all params used, and sort them by name to form the DataFrame schema.
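
A rough sketch of that suggestion (illustrative only; paramMaps stands for the $(estimatorParamMaps) array, and metrics and metricName are the values already in scope in getTuningSummaryDF):

// Sketch: build the schema from the union of params across all candidate maps,
// so candidates with different param sets still line up column by column.
val allParamNames = paramMaps.flatMap(_.toSeq.map(_.param.name)).distinct.sorted
val schema = StructType(
  (allParamNames :+ metricName).map(name => StructField(name, StringType)))
val rows = paramMaps.zip(metrics).map { case (paramMap, metric) =>
  val byName = paramMap.toSeq.map(pair => pair.param.name -> pair.value.toString).toMap
  // Leave a cell blank when this candidate map does not set the param.
  Row.fromSeq(allParamNames.map(byName.getOrElse(_, "")) :+ metric.toString)
}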

Reviewer comment (Contributor):
Also, the param values are rendered with toString here, and some array-typed params convert to an unreadable string.
For example, for a DoubleArrayParam, doubleArray.toString becomes "[DXXXXX".
Using Param.jsonEncode is better.
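
For illustration, the value rendering could then go through jsonEncode, roughly like this (a sketch, not the PR's code; paramMap and metric are the pair being mapped over, and the cast is needed because ParamPair erases the concrete value type):

val values = paramMap.toSeq.sortBy(_.param.name).map { pair =>
  // jsonEncode yields a readable, round-trippable string even for array-typed params.
  pair.param.asInstanceOf[Param[Any]].jsonEncode(pair.value)
} :+ metric.toString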

Author reply:
Thanks, we should support the case of custom paramMaps.

Row.fromSeq(values)
}
Reviewer comment (Contributor):
The variable names here are a little confusing; renaming
params ==> paramMaps
case (param, metric) ==> case (paramMap, metric)
would be clearer.

Author reply:
OK

spark.createDataFrame(rows, schema)
}
}

private[ml] object ValidatorParams {
@@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.StructType

class CrossValidatorSuite
@@ -66,6 +66,26 @@ class CrossValidatorSuite
assert(cvModel.avgMetrics.length === lrParamMaps.length)
}

test("cross validation with tuning summary") {
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.001, 1.0, 1000.0))
.addGrid(lr.maxIter, Array(0, 2))
.build()
val eval = new BinaryClassificationEvaluator
val cv = new CrossValidator()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setNumFolds(3)
val cvModel = cv.fit(dataset)
val expected = lrParamMaps.zip(cvModel.avgMetrics).map { case (map, metric) =>
Row.fromSeq(map.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString))
}
assert(cvModel.tuningSummary.collect().toSet === expected.toSet)
assert(cvModel.tuningSummary.columns.last === eval.getMetricName)
}

Reviewer comment (Contributor):
Shall we add a test for the exception being thrown if no summary?

test("cross validation with linear regression") {
val dataset = sc.parallelize(
LinearDataGenerator.generateLinearInput(
@@ -23,12 +23,12 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{ParamMap}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.StructType

class TrainValidationSplitSuite
@@ -59,6 +59,26 @@ class TrainValidationSplitSuite
assert(tvsModel.validationMetrics.length === lrParamMaps.length)
}

test("train validation split with tuning summary") {
val dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.001, 1.0, 1000.0))
.addGrid(lr.maxIter, Array(0, 2))
.build()
val eval = new BinaryClassificationEvaluator
val tvs = new TrainValidationSplit()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
val tvsModel = tvs.fit(dataset)
val expected = lrParamMaps.zip(tvsModel.validationMetrics).map { case (map, metric) =>
Row.fromSeq(map.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString))
}
assert(tvsModel.tuningSummary.collect().toSet === expected.toSet)
assert(tvsModel.tuningSummary.columns.last === eval.getMetricName)
}

test("train validation with linear regression") {
val dataset = sc.parallelize(
LinearDataGenerator.generateLinearInput(
@@ -86,7 +106,7 @@ class TrainValidationSplitSuite
assert(parent.getMaxIter === 10)
assert(tvsModel.validationMetrics.length === lrParamMaps.length)

eval.setMetricName("r2")
eval.setMetricName("r2")
val tvsModel2 = tvs.fit(dataset)
val parent2 = tvsModel2.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent2.getRegParam === 0.001)