Skip to content

Commit 0a698fe

Browse files
committed
support cross validation
1 parent 1a594d0 commit 0a698fe

File tree

6 files changed

+94
-12
lines changed

6 files changed

+94
-12
lines changed

examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ object ModelSelectionViaCrossValidationExample {
112112
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
113113
println(s"($id, $text) --> prob=$prob, prediction=$prediction")
114114
}
115+
cvModel.summary.trainingMetrics.show()
115116
// $example off$
116117

117118
spark.stop()

examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ object ModelSelectionViaTrainValidationSplitExample {
7474
model.transform(test)
7575
.select("features", "label", "prediction")
7676
.show()
77+
model.summary.trainingMetrics.show()
7778
// $example off$
7879

7980
spark.stop()

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import com.github.fommil.netlib.F2jBLAS
2525
import org.apache.hadoop.fs.Path
2626
import org.json4s.DefaultFormats
2727

28+
import org.apache.spark.SparkException
2829
import org.apache.spark.annotation.Since
2930
import org.apache.spark.internal.Logging
3031
import org.apache.spark.ml._
@@ -133,7 +134,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
133134
logInfo(s"Best cross-validation metric: $bestMetric.")
134135
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
135136
instr.logSuccess(bestModel)
136-
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
137+
val model = new CrossValidatorModel(uid, bestModel, metrics).setParent(this)
138+
val summary = new TuningSummary(epm, metrics, bestIndex)
139+
model.setSummary(Some(summary))
140+
copyValues(model)
137141
}
138142

139143
@Since("1.4.0")
@@ -229,6 +233,29 @@ class CrossValidatorModel private[ml] (
229233
bestModel.transformSchema(schema)
230234
}
231235

236+
private var trainingSummary: Option[TuningSummary] = None
237+
238+
private[tuning] def setSummary(summary: Option[TuningSummary]): this.type = {
239+
this.trainingSummary = summary
240+
this
241+
}
242+
243+
/**
244+
* Return true if there exists summary of model.
245+
*/
246+
@Since("2.3.0")
247+
def hasSummary: Boolean = trainingSummary.nonEmpty
248+
249+
/**
250+
* Gets summary of model on training set. An exception is
251+
* thrown if `trainingSummary == None`.
252+
*/
253+
@Since("2.3.0")
254+
def summary: TuningSummary = trainingSummary.getOrElse {
255+
throw new SparkException(
256+
s"No training summary available for the ${this.getClass.getSimpleName}")
257+
}
258+
232259
@Since("1.4.0")
233260
override def copy(extra: ParamMap): CrossValidatorModel = {
234261
val copied = new CrossValidatorModel(

mllib/src/main/scala/org/apache/spark/ml/tuning/TuningSummary.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
1817
package org.apache.spark.ml.tuning
1918

2019
import org.apache.spark.annotation.{Experimental, Since}
@@ -29,15 +28,16 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType}
2928
* @param params estimator param maps
3029
* @param metrics Corresponding evaluation metrics for the param maps
3130
*/
32-
@Since("2.2.0")
31+
@Since("2.3.0")
3332
@Experimental
34-
class TuningSummary private[tuning](
33+
private[tuning] class TuningSummary private[tuning](
3534
val params: Array[ParamMap],
3635
val metrics: Array[Double],
3736
val bestIndex: Int) {
3837

3938
/**
40-
* Summary of grid search tuning in the format of DataFrame.
39+
* Summary of grid search tuning in the format of DataFrame. Each row contains one candidate
40+
* paramMap and its corresponding metrics.
4141
*/
4242
def trainingMetrics: DataFrame = {
4343
require(params.nonEmpty, "estimator param maps should not be empty")

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,34 @@ class CrossValidatorSuite
6666
assert(cvModel.avgMetrics.length === lrParamMaps.length)
6767
}
6868

69+
test("cross validation with tuning summary") {
70+
val lr = new LogisticRegression
71+
val lrParamMaps = new ParamGridBuilder()
72+
.addGrid(lr.regParam, Array(0.001, 1.0, 1000.0))
73+
.addGrid(lr.maxIter, Array(0, 2))
74+
.build()
75+
val eval = new BinaryClassificationEvaluator
76+
val cv = new CrossValidator()
77+
.setEstimator(lr)
78+
.setEstimatorParamMaps(lrParamMaps)
79+
.setEvaluator(eval)
80+
.setNumFolds(3)
81+
val cvModel = cv.fit(dataset)
82+
assert(cvModel.hasSummary)
83+
assert(cvModel.summary.params === lrParamMaps)
84+
assert(cvModel.summary.trainingMetrics.count() === lrParamMaps.length)
85+
val expectedSummary = spark.createDataFrame(Seq(
86+
(0, 0.001),
87+
(2, 0.001),
88+
(0, 1.0),
89+
(2, 1.0),
90+
(0, 1000.0),
91+
(2, 1000.0)
92+
).map(t => (t._1.toString, t._2.toString))).toDF("maxIter", "regParam")
93+
assert(cvModel.summary.trainingMetrics.select("maxIter", "regParam").collect().toSet
94+
.equals(expectedSummary.collect().toSet))
95+
}
96+
6997
test("cross validation with linear regression") {
7098
val dataset = sc.parallelize(
7199
LinearDataGenerator.generateLinearInput(

mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,40 @@ class TrainValidationSplitSuite
5151
.setEvaluator(eval)
5252
.setTrainRatio(0.5)
5353
.setSeed(42L)
54-
val tvModel = tvs.fit(dataset)
55-
val parent = tvModel.bestModel.parent.asInstanceOf[LogisticRegression]
54+
val tvsModel = tvs.fit(dataset)
55+
val parent = tvsModel.bestModel.parent.asInstanceOf[LogisticRegression]
5656
assert(tvs.getTrainRatio === 0.5)
5757
assert(parent.getRegParam === 0.001)
5858
assert(parent.getMaxIter === 10)
59-
assert(tvModel.validationMetrics.length === lrParamMaps.length)
60-
assert(tvModel.summary.params === lrParamMaps)
61-
assert(tvModel.summary.trainingMetrics.count() === lrParamMaps.length)
62-
assert(tvModel.summary.trainingMetrics.columns === Array("maxIter", "regParam", "metrics"))
59+
assert(tvsModel.validationMetrics.length === lrParamMaps.length)
60+
}
61+
62+
test("train validation split with tuning summary") {
63+
val dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
64+
val lr = new LogisticRegression
65+
val lrParamMaps = new ParamGridBuilder()
66+
.addGrid(lr.regParam, Array(0.001, 1.0, 1000.0))
67+
.addGrid(lr.maxIter, Array(0, 2))
68+
.build()
69+
val eval = new BinaryClassificationEvaluator
70+
val tvs = new TrainValidationSplit()
71+
.setEstimator(lr)
72+
.setEstimatorParamMaps(lrParamMaps)
73+
.setEvaluator(eval)
74+
val tvsModel = tvs.fit(dataset)
75+
assert(tvsModel.hasSummary)
76+
assert(tvsModel.summary.params === lrParamMaps)
77+
assert(tvsModel.summary.trainingMetrics.count() === lrParamMaps.length)
78+
val expectedSummary = spark.createDataFrame(Seq(
79+
(0, 0.001),
80+
(2, 0.001),
81+
(0, 1.0),
82+
(2, 1.0),
83+
(0, 1000.0),
84+
(2, 1000.0)
85+
).map(t => (t._1.toString, t._2.toString))).toDF("maxIter", "regParam")
86+
assert(tvsModel.summary.trainingMetrics.select("maxIter", "regParam").collect().toSet
87+
.equals(expectedSummary.collect().toSet))
6388
}
6489

6590
test("train validation with linear regression") {
@@ -89,7 +114,7 @@ class TrainValidationSplitSuite
89114
assert(parent.getMaxIter === 10)
90115
assert(tvsModel.validationMetrics.length === lrParamMaps.length)
91116

92-
eval.setMetricName("r2")
117+
eval.setMetricName("r2")
93118
val tvsModel2 = tvs.fit(dataset)
94119
val parent2 = tvsModel2.bestModel.parent.asInstanceOf[LinearRegression]
95120
assert(parent2.getRegParam === 0.001)

0 commit comments

Comments
 (0)