Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit. Hold Shift and click to select a range.
d1e22d5
tuning summary
YY-OnCall Dec 3, 2016
a7cfa63
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Dec 5, 2016
ad73c12
add ut
YY-OnCall Dec 5, 2016
425a419
add comments
YY-OnCall Dec 5, 2016
bd18c00
get default spark session
YY-OnCall Dec 13, 2016
2a0af1d
resolve merge conflict
YY-OnCall Feb 22, 2017
1a594d0
merge conflict
YY-OnCall Jul 5, 2017
0a698fe
support cross validation
YY-OnCall Jul 5, 2017
bd459b1
update version
YY-OnCall Jul 5, 2017
4e3e19c
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Jul 24, 2017
c0bc81a
improve unit test
YY-OnCall Jul 24, 2017
bbf3f9f
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Aug 8, 2017
b6a7c53
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Aug 9, 2017
72aea62
remove TuningSummary
YY-OnCall Aug 9, 2017
91da358
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Sep 10, 2017
297091f
update for pipeline
YY-OnCall Sep 11, 2017
670467a
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Nov 19, 2017
36b1dd5
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Dec 29, 2017
5844e0c
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Jan 1, 2018
8c829b5
merge conflict
YY-OnCall Jul 23, 2018
4aaf8e5
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Jul 23, 2018
ceaad1c
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Jul 27, 2018
41c4c12
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Jul 29, 2018
4aef3aa
remove sort add comments
YY-OnCall Jul 29, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next Next commit
improve unit test
  • Loading branch information
YY-OnCall committed Jul 24, 2017
commit c0bc81a976bb00d1ad5b31164f61df7fa971cc91
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,22 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType}

/**
* :: Experimental ::
* Summary of grid search tuning.
* Summary for the grid search tuning.
*
* @param params estimator param maps
* @param metrics Corresponding evaluation metrics for the param maps
* @param params ParamMaps for the Estimator
* @param metrics corresponding evaluation metrics for the params
* @param bestIndex index in params for the ParamMap of the best model.
*/
@Since("2.3.0")
@Experimental
private[tuning] class TuningSummary private[tuning](
val params: Array[ParamMap],
val metrics: Array[Double],
val bestIndex: Int) {
private[tuning] val params: Array[ParamMap],
private[tuning] val metrics: Array[Double],
private[tuning] val bestIndex: Int) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears bestIndex is never used?


/**
* Summary of grid search tuning in the format of DataFrame. Each row contains one candidate
* paramMap and its corresponding metrics.
* paramMap and its corresponding metric.
*/
def trainingMetrics: DataFrame = {
require(params.nonEmpty, "estimator param maps should not be empty")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.StructType

class CrossValidatorSuite
Expand Down Expand Up @@ -82,16 +82,11 @@ class CrossValidatorSuite
assert(cvModel.hasSummary)
assert(cvModel.summary.params === lrParamMaps)
assert(cvModel.summary.trainingMetrics.count() === lrParamMaps.length)
val expectedSummary = spark.createDataFrame(Seq(
(0, 0.001),
(2, 0.001),
(0, 1.0),
(2, 1.0),
(0, 1000.0),
(2, 1000.0)
).map(t => (t._1.toString, t._2.toString))).toDF("maxIter", "regParam")
assert(cvModel.summary.trainingMetrics.select("maxIter", "regParam").collect().toSet
.equals(expectedSummary.collect().toSet))

val expected = lrParamMaps.zip(cvModel.avgMetrics).map { case (map, metric) =>
Row.fromSeq(map.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString))
}
assert(cvModel.summary.trainingMetrics.collect().toSet === expected.toSet)
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we add a test for the exception being thrown if no summary?

test("cross validation with linear regression") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{ParamMap}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.StructType

class TrainValidationSplitSuite
Expand Down Expand Up @@ -75,16 +75,11 @@ class TrainValidationSplitSuite
assert(tvsModel.hasSummary)
assert(tvsModel.summary.params === lrParamMaps)
assert(tvsModel.summary.trainingMetrics.count() === lrParamMaps.length)
val expectedSummary = spark.createDataFrame(Seq(
(0, 0.001),
(2, 0.001),
(0, 1.0),
(2, 1.0),
(0, 1000.0),
(2, 1000.0)
).map(t => (t._1.toString, t._2.toString))).toDF("maxIter", "regParam")
assert(tvsModel.summary.trainingMetrics.select("maxIter", "regParam").collect().toSet
.equals(expectedSummary.collect().toSet))

val expected = lrParamMaps.zip(tvsModel.validationMetrics).map { case (map, metric) =>
Row.fromSeq(map.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString))
}
assert(tvsModel.summary.trainingMetrics.collect().toSet === expected.toSet)
}

test("train validation with linear regression") {
Expand Down