Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ private[shared] object SharedParamsCodeGen {
"all instance weights as 1.0"),
ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
ParamDesc[Boolean]("collectSubModels", "whether to collect sub models when tuning fitting",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: reword "whether to collect sub models when tuning fitting" --> "whether to collect a list of sub-models trained during tuning"

Some("false"), isExpertParam = true),
ParamDesc[String]("persistSubModelsPath", "The path to persist sub models when " +
"tuning fitting", Some("\"\""), isExpertParam = true)
)

val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,4 +402,38 @@ private[ml] trait HasAggregationDepth extends Params {
/** @group expertGetParam */
final def getAggregationDepth: Int = $(aggregationDepth)
}

/**
* Trait for shared param collectSubModels (default: false).
*/
private[ml] trait HasCollectSubModels extends Params {

/**
* Param for whether to collect sub models when tuning fitting.
* @group expertParam
*/
final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect sub models when tuning fitting")

setDefault(collectSubModels, false)

/** @group expertGetParam */
final def getCollectSubModels: Boolean = $(collectSubModels)
}

/**
* Trait for shared param persistSubModelsPath (default: "").
*/
private[ml] trait HasPersistSubModelsPath extends Params {

/**
* Param for The path to persist sub models when tuning fitting.
* @group expertParam
*/
final val persistSubModelsPath: Param[String] = new Param[String](this, "persistSubModelsPath", "The path to persist sub models when tuning fitting")

setDefault(persistSubModelsPath, "")

/** @group expertGetParam */
final def getPersistSubModelsPath: String = $(persistSubModelsPath)
}
// scalastyle:on
129 changes: 110 additions & 19 deletions mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.ml.tuning

import java.io.IOException
import java.util.{List => JList}

import scala.collection.JavaConverters._
Expand All @@ -31,7 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.HasParallelism
import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism, HasPersistSubModelsPath}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{DataFrame, Dataset}
Expand Down Expand Up @@ -67,7 +68,8 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
@Since("1.2.0")
class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
extends Estimator[CrossValidatorModel]
with CrossValidatorParams with HasParallelism with MLWritable with Logging {
with CrossValidatorParams with HasParallelism with HasCollectSubModels
with HasPersistSubModelsPath with MLWritable with Logging {

@Since("1.2.0")
def this() = this(Identifiable.randomUID("cv"))
Expand Down Expand Up @@ -101,6 +103,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("2.3.0")
def setParallelism(value: Int): this.type = set(parallelism, value)

/** @group expertSetParam */
@Since("2.3.0")
def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value)

/** @group expertSetParam */
@Since("2.3.0")
def setPersistSubModelsPath(value: String): this.type = set(persistSubModelsPath, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val schema = dataset.schema
Expand All @@ -117,6 +127,13 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
instr.logParams(numFolds, seed, parallelism)
logTuningParams(instr)

val collectSubModelsParam = $(collectSubModels)
val persistSubModelsPathParam = $(persistSubModelsPath)

var subModels: Array[Array[Model[_]]] = if (collectSubModelsParam) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps use an Option[Array[Model[_]]] instead of setting subModels to null?

Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null))
} else null

// Compute metrics for each model over each split
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
Expand All @@ -125,10 +142,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
logDebug(s"Train split $splitIndex with multiple sets of parameters.")

// Fit models in a Future for training in parallel
val modelFutures = epm.map { paramMap =>
val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Model[_]] {
val model = est.fit(trainingDataset, paramMap)
model.asInstanceOf[Model[_]]
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]

if (collectSubModelsParam) {
subModels(splitIndex)(paramIndex) = model
}
if (persistSubModelsPathParam.nonEmpty) {
val modelPath = new Path(new Path(persistSubModelsPathParam, splitIndex.toString),
paramIndex.toString).toString
model.asInstanceOf[MLWritable].save(modelPath)
}
model
} (executionContext)
}

Expand Down Expand Up @@ -160,7 +186,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
copyValues(new CrossValidatorModel(uid, bestModel, metrics, subModels).setParent(this))
}

@Since("1.4.0")
Expand Down Expand Up @@ -212,14 +238,12 @@ object CrossValidator extends MLReadable[CrossValidator] {

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val seed = (metadata.params \ "seed").extract[Long]
new CrossValidator(metadata.uid)
val cv = new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
.setNumFolds(numFolds)
.setSeed(seed)
DefaultParamsReader.getAndSetParams(cv, metadata, skipParams = List("estimatorParamMaps"))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use getAndSetParams instead of setting all params manually. This simplify code, and it can keep read/write compatibility.

cv
}
}
}
Expand All @@ -237,12 +261,17 @@ object CrossValidator extends MLReadable[CrossValidator] {
class CrossValidatorModel private[ml] (
@Since("1.4.0") override val uid: String,
@Since("1.2.0") val bestModel: Model[_],
@Since("1.5.0") val avgMetrics: Array[Double])
@Since("1.5.0") val avgMetrics: Array[Double],
@Since("2.3.0") val subModels: Array[Array[Model[_]]])
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {

/** A Python-friendly auxiliary constructor. */
private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = {
this(uid, bestModel, avgMetrics.asScala.toArray)
this(uid, bestModel, avgMetrics.asScala.toArray, null)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See earlier suggestion, use an Option set to None instead of setting the Array to null

}

private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: Array[Double]) = {
this(uid, bestModel, avgMetrics, null)
}

@Since("2.0.0")
Expand All @@ -261,17 +290,40 @@ class CrossValidatorModel private[ml] (
val copied = new CrossValidatorModel(
uid,
bestModel.copy(extra).asInstanceOf[Model[_]],
avgMetrics.clone())
avgMetrics.clone(),
CrossValidatorModel.copySubModels(subModels))
copyValues(copied, extra).setParent(parent)
}

@Since("1.6.0")
override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this)

@Since("2.3.0")
@throws[IOException]("If the input path already exists but overwrite is not enabled.")
def save(path: String, persistSubModels: Boolean): Unit = {
write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter]
.persistSubModels(persistSubModels).save(path)
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I add this method because the CrossValidatorModelWriter is private. User cannot use it. But I don't know whether there is better solution.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think users can still access CrossValidatorModelWriter through CrossValidatorModel.write, so the save method is unnecessary.

The private[CrossValidatorModel] annotation on the CrossValidatorModelWriter constructor only means that users can't create instances of the class e.g. via new CrossValidatorModel.CrossValidatorModelWriter(...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried model.write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter] but cannot pass complier, it is inaccessible.
Do you have some other ways ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussion: Another way I think is adding an interface def option(key: String, value: String) into Writer. cc @jkbradley

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the last suggestion of adding def option(key: String, value: String) to mimic the SQL datasource API.

}

@Since("1.6.0")
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {

private[CrossValidatorModel] def copySubModels(subModels: Array[Array[Model[_]]]) = {
var copiedSubModels: Array[Array[Model[_]]] = null
if (subModels != null) {
val numFolds = subModels.length
val numParamMaps = subModels(0).length
copiedSubModels = Array.fill(numFolds)(Array.fill[Model[_]](numParamMaps)(null))
for (i <- 0 until numFolds) {
for (j <- 0 until numParamMaps) {
copiedSubModels(i)(j) = subModels(i)(j).copy(ParamMap.empty).asInstanceOf[Model[_]]
}
}
}
copiedSubModels
}

@Since("1.6.0")
override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader

Expand All @@ -283,12 +335,35 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {

ValidatorParams.validateParams(instance)

protected var shouldPersistSubModels: Boolean = false

/**
* Set option for persist sub models.
*/
@Since("2.3.0")
def persistSubModels(persist: Boolean): this.type = {
shouldPersistSubModels = persist
this
}

override protected def saveImpl(path: String): Unit = {
import org.json4s.JsonDSL._
val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~
("shouldPersistSubModels" -> shouldPersistSubModels)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's have 1 name for this argument: "persistSubModels"

ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
if (shouldPersistSubModels) {
require(instance.subModels != null, "Cannot get sub models to persist.")
val subModelsPath = new Path(path, "subModels")
for (splitIndex <- 0 until instance.getNumFolds) {
val splitPath = new Path(subModelsPath, splitIndex.toString)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about naming this with the string "fold":
splitIndex.toString --> "fold" + splitIndex.toString?

for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
val modelPath = new Path(splitPath, paramIndex.toString).toString
instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath)
}
}
}
}
}

Expand All @@ -303,16 +378,32 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val seed = (metadata.params \ "seed").extract[Long]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
val shouldPersistSubModels = (metadata.metadata \ "shouldPersistSubModels").extract[Boolean]

val subModels: Array[Array[Model[_]]] = if (shouldPersistSubModels) {
val subModelsPath = new Path(path, "subModels")
val _subModels = Array.fill(numFolds)(Array.fill[Model[_]](
estimatorParamMaps.length)(null))
for (splitIndex <- 0 until numFolds) {
val splitPath = new Path(subModelsPath, splitIndex.toString)
for (paramIndex <- 0 until estimatorParamMaps.length) {
val modelPath = new Path(splitPath, paramIndex.toString).toString
_subModels(splitIndex)(paramIndex) =
DefaultParamsReader.loadParamsInstance(modelPath, sc)
}
}
_subModels
} else null

val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics, subModels)
model.set(model.estimator, estimator)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
.set(model.numFolds, numFolds)
.set(model.seed, seed)
DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps"))
model
}
}
}
Loading