Closed
Changes from 1 commit
24 commits
d1e22d5
tuning summary
YY-OnCall Dec 3, 2016
a7cfa63
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Dec 5, 2016
ad73c12
add ut
YY-OnCall Dec 5, 2016
425a419
add comments
YY-OnCall Dec 5, 2016
bd18c00
get default spark session
YY-OnCall Dec 13, 2016
2a0af1d
resolve merge conflict
YY-OnCall Feb 22, 2017
1a594d0
merge conflict
YY-OnCall Jul 5, 2017
0a698fe
support cross validation
YY-OnCall Jul 5, 2017
bd459b1
update version
YY-OnCall Jul 5, 2017
4e3e19c
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Jul 24, 2017
c0bc81a
improve unit test
YY-OnCall Jul 24, 2017
bbf3f9f
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Aug 8, 2017
b6a7c53
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Aug 9, 2017
72aea62
remove TuningSummary
YY-OnCall Aug 9, 2017
91da358
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Sep 10, 2017
297091f
update for pipeline
YY-OnCall Sep 11, 2017
670467a
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Nov 19, 2017
36b1dd5
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Dec 29, 2017
5844e0c
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Jan 1, 2018
8c829b5
merge conflict
YY-OnCall Jul 23, 2018
4aaf8e5
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Jul 23, 2018
ceaad1c
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Jul 27, 2018
41c4c12
Merge remote-tracking branch 'upstream/master' into tuningsummary
YY-OnCall Jul 29, 2018
4aef3aa
remove sort add comments
YY-OnCall Jul 29, 2018
tuning summary
YY-OnCall committed Dec 3, 2016
commit d1e22d58f5ecb5972f2ea528dc18d1230d678424
Original file line number Diff line number Diff line change
@@ -20,11 +20,10 @@ package org.apache.spark.ml.tuning
import java.util.{List => JList}

import scala.collection.JavaConverters._

import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.json4s.DefaultFormats

import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
@@ -127,7 +126,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
- copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
+ val model = copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
+ val summary = new TuningSummary(bestModel.transform(dataset), epm, metrics, bestIndex)
+ model.setSummary(Some(summary))
Contributor

Just to confirm, the tuning summary will not be saved? Since it's a small dataframe, perhaps we should consider saving it with the model? (Can do that in a later PR however)

Contributor Author

@hhbyyh hhbyyh Aug 3, 2017

If we want to just save the tuning summary in the model, perhaps we can just discard the TuningSummary, and add a tuningSummary: DataFrame field/function in the models. Sounds good?

Contributor

Are there other obvious things that might go into the summary in future, that would make a TuningSummary class a better fit?

Future support for, say, multiple metrics could simply extend the dataframe columns, so that is OK. But is there anything else you can think of?

Contributor Author

There might be something like a detailed training log and the training time for each model. But I think the current Summary pattern has some room for improvement (e.g., save/load and the API); it bothers me to have to duplicate code like
def hasSummary: Boolean = trainingSummary.nonEmpty. So saving it with the models sounds like a good idea to me.

Contributor Author

The latest implementation does not need to save the extra dataframe, since the dataframe can be generated from $(estimatorParamMaps) and avgMetrics.
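
As a rough illustration of that point, the summary table can be derived on demand from the parameter grid and the averaged metrics, so nothing extra has to be persisted with the model. The sketch below is not the code from this PR: it uses plain Scala collections as stand-ins, with `Map[String, String]` in place of Spark's `ParamMap`, and the names `TuningSummarySketch`, `ParamGridPoint`, `SummaryTable`, and `buildSummary` are all hypothetical.

```scala
object TuningSummarySketch {
  // Hypothetical stand-in for one ParamMap: parameter name -> value rendered as a string.
  type ParamGridPoint = Map[String, String]

  // Hypothetical container for the derived table: column names plus one row per grid point.
  final case class SummaryTable(columns: Seq[String], rows: Seq[Seq[String]])

  def buildSummary(paramMaps: Seq[ParamGridPoint], avgMetrics: Seq[Double]): SummaryTable = {
    require(paramMaps.nonEmpty, "estimator param maps should not be empty")
    require(paramMaps.length == avgMetrics.length,
      "number of estimator param maps should match number of metrics")
    // Sort parameter names so every row lists its values in the same column order.
    val paramNames = paramMaps.head.keys.toSeq.sorted
    val rows = paramMaps.zip(avgMetrics).map { case (point, metric) =>
      // A parameter missing from a grid point is rendered as an empty string.
      paramNames.map(name => point.getOrElse(name, "")) :+ metric.toString
    }
    SummaryTable(paramNames :+ "metric", rows)
  }
}
```

Because the table is a pure function of the two inputs, a reloaded model that still carries its param maps and metrics could rebuild it without any additional saved state.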

+ model
}

@Since("1.4.0")
@@ -234,6 +236,29 @@ class CrossValidatorModel private[ml] (

@Since("1.6.0")
override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this)

private var trainingSummary: Option[TuningSummary] = None

private[tuning] def setSummary(summary: Option[TuningSummary]): this.type = {
this.trainingSummary = summary
this
}

/**
* Returns true if a summary exists for this model.
*/
@Since("2.0.0")
def hasSummary: Boolean = trainingSummary.nonEmpty

/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@Since("2.0.0")
def summary: TuningSummary = trainingSummary.getOrElse {
throw new SparkException(
s"No training summary available for the ${this.getClass.getSimpleName}")
}
}

@Since("1.6.0")
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@ import scala.language.existentials
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
@@ -123,7 +124,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best train validation split metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
- copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
+ val model = copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
+ val summary = new TuningSummary(bestModel.transform(dataset), epm, metrics, bestIndex)
Contributor

It seems wasteful to do bestModel.transform(dataset) just to get access to the sqlContext. Is it really necessary?

Contributor Author

@hhbyyh hhbyyh Dec 13, 2016

Indeed that's not necessary. I just replaced it with SparkSession.builder().getOrCreate(). Is there a better way to get the default contexts? Thanks

+ model.setSummary(Some(summary))
+ model
}

@Since("1.5.0")
@@ -226,6 +230,29 @@ class TrainValidationSplitModel private[ml] (

@Since("2.0.0")
override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)

private var trainingSummary: Option[TuningSummary] = None

private[tuning] def setSummary(summary: Option[TuningSummary]): this.type = {
this.trainingSummary = summary
this
}

/**
* Returns true if a summary exists for this model.
*/
@Since("2.0.0")
def hasSummary: Boolean = trainingSummary.nonEmpty

/**
* Gets summary of model on training set. An exception is
Contributor

Should probably rather be "summary of model performance on the validation set"?

* thrown if `trainingSummary == None`.
*/
@Since("2.0.0")
def summary: TuningSummary = trainingSummary.getOrElse {
throw new SparkException(
s"No training summary available for the ${this.getClass.getSimpleName}")
}
Contributor Author

@hhbyyh hhbyyh Dec 5, 2016

I'm thinking we should add a new trait hasSummary to wrap the summary-related code. I can create another jira if that's reasonable.

Contributor Author

addressed in #17654
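
For readers following the thread, the duplication being discussed could be factored out roughly as below. This is only a sketch of the idea, not the change made in #17654: `SummaryHolder` and `ToyModel` are hypothetical names, the setter is left public for simplicity, and a plain `RuntimeException` stands in for `SparkException` so the snippet has no Spark dependency.

```scala
// Hypothetical mix-in that stores an optional summary once, so each model class
// does not have to repeat the field, the setter, and the accessors.
trait SummaryHolder[T] {
  private var storedSummary: Option[T] = None

  // Attaches (or clears) the summary and returns this instance for chaining.
  def setSummary(summary: Option[T]): this.type = {
    storedSummary = summary
    this
  }

  // True once a summary has been attached.
  def hasSummary: Boolean = storedSummary.nonEmpty

  // Returns the summary, or throws if none has been attached.
  def summary: T = storedSummary.getOrElse {
    throw new RuntimeException(
      s"No training summary available for this ${getClass.getSimpleName}")
  }
}

// Hypothetical toy model, present only to show the mix-in in use.
final class ToyModel extends SummaryHolder[String]
```

With a mix-in of this shape, both `CrossValidatorModel` and `TrainValidationSplitModel` could drop their identical `trainingSummary`, `hasSummary`, and `summary` members.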

}

@Since("2.0.0")
@@ -275,3 +302,4 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
}
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.tuning

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

/**
* :: Experimental ::
* Summary of grid search tuning.
*
* @param params estimator param maps
* @param metrics Corresponding evaluation metrics for the param maps
*/
@Since("2.2.0")
@Experimental
class TuningSummary private[tuning](
@transient val predictions: DataFrame,
val params: Array[ParamMap],
val metrics: Array[Double],
val bestIndex: Int) {

def trainingMetrics: DataFrame = {
require(params.nonEmpty, "estimator param maps should not be empty")
require(params.length == metrics.length, "number of estimator param maps should match number of metrics")
val sqlContext = predictions.sqlContext
val sc = sqlContext.sparkContext
val fields = params(0).toSeq.sortBy(_.param.name).map(_.param.name) ++ Seq("metrics")
Contributor

"metrics" is a bit generic. Perhaps it would be better (and more user-friendly) to name this column after the metric being optimized, in the form "<metric_name> metric", so that it is obvious what was optimized for? For example "ROC metric", "AUC metric", or "MSE metric".
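
One way to read this suggestion is sketched below: the header keeps the sorted parameter names and replaces the fixed "metrics" label with one derived from the evaluator's metric name. `MetricColumnSketch` and `summaryColumns` are hypothetical names, and passing the metric name in as a plain string is an assumption; the PR does not show how it would be obtained.

```scala
object MetricColumnSketch {
  // Builds the header: sorted parameter names, then a column named after the metric.
  def summaryColumns(paramNames: Seq[String], metricName: String): Seq[String] =
    paramNames.sorted :+ s"$metricName metric"
}
```

For a `BinaryClassificationEvaluator` left at its default metric, `areaUnderROC`, the last column would then be labelled "areaUnderROC metric".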

val schema = new StructType(fields.map(name => StructField(name, StringType)).toArray)

val rows = sc.parallelize(params.zip(metrics)).map { case (param, metric) =>
val values = param.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString)
Row.fromSeq(values)
}
sqlContext.createDataFrame(rows, schema)
}
}

Original file line number Diff line number Diff line change
@@ -45,18 +45,19 @@ class TrainValidationSplitSuite
.addGrid(lr.maxIter, Array(0, 10))
.build()
val eval = new BinaryClassificationEvaluator
- val cv = new TrainValidationSplit()
+ val tv = new TrainValidationSplit()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
.setSeed(42L)
- val cvModel = cv.fit(dataset)
- val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
- assert(cv.getTrainRatio === 0.5)
+ val tvModel = tv.fit(dataset)
+ val parent = tvModel.bestModel.parent.asInstanceOf[LogisticRegression]
+ assert(tv.getTrainRatio === 0.5)
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
- assert(cvModel.validationMetrics.length === lrParamMaps.length)
+ assert(tvModel.validationMetrics.length === lrParamMaps.length)
+ assert(tvModel.summary.params === lrParamMaps)
}

test("train validation with linear regression") {