Skip to content
Closed
Prev Previous commit
Next Next commit
address comments from joseph
  • Loading branch information
WeichenXu123 committed Nov 7, 2017
commit 7bacfcac2e20552bb4557614ba477cf776bdf8af
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.ml.tuning

import java.util.{List => JList}
import java.util.{List => JList, Locale}

import scala.collection.JavaConverters._
import scala.concurrent.Future
Expand Down Expand Up @@ -282,15 +282,14 @@ class CrossValidatorModel private[ml] (
/**
* @return submodels represented in two dimension array. The index of outer array is the
* fold index, and the index of inner array corresponds to the ordering of
* estimatorParamsMaps
*
* Note: If submodels not available, exception will be thrown. only when we set collectSubModels
* Param before fitting, submodels will be available.
* estimatorParamMaps
* @throws IllegalArgumentException if subModels are not available. To retrieve subModels,
* make sure to set collectSubModels to true before fitting.
*/
@Since("2.3.0")
def subModels: Array[Array[Model[_]]] = {
require(_subModels.isDefined, "submodels not available, set collectSubModels param before " +
"fitting will address this issue.")
require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " +
"to set collectSubModels to true before fitting.")
_subModels.get
}

Expand Down Expand Up @@ -342,22 +341,27 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
* Writer for CrossValidatorModel.
* @param instance CrossValidatorModel instance used to construct the writer
*
* Options:
* CrossValidatorModelWriter support an option "persistSubModels", available value is
* "true" or "false". If you set collectSubModels param before fitting, and then you can set
* the option "persistSubModels" to be "true" and the submodels will be persisted.
* The default value of "persistSubModels" will be "true", if you set collectSubModels
* param before fitting, but if you do not set collectSubModels param before fitting, setting
* "persistSubModels" will cause exception.
* CrossValidatorModelWriter supports an option "persistSubModels", with possible values
* "true" or "false". If you set the collectSubModels Param before fitting, then you can
* set "persistSubModels" to "true" in order to persist the subModels. By default,
* "persistSubModels" will be "true" when subModels are available and "false" otherwise.
* If subModels are not available, then setting "persistSubModels" to "true" will cause
* an exception.
*/
@Since("2.3.0")
class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
final class CrossValidatorModelWriter private[tuning] (
instance: CrossValidatorModel) extends MLWriter {

ValidatorParams.validateParams(instance)

override protected def saveImpl(path: String): Unit = {
val persistSubModels = optionMap.getOrElse("persistsubmodels",
if (instance.hasSubModels) "true" else "false").toBoolean
val persistSubModelsParam = optionMap.getOrElse("persistsubmodels",
if (instance.hasSubModels) "true" else "false")

require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)),
s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " +
"values are \"true\" or \"false\"")
val persistSubModels = persistSubModelsParam.toBoolean

import org.json4s.JsonDSL._
val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
* Note: If set this param, when you save the returned model, you can set an option
* "persistSubModels" to be "true" before saving, in order to save these submodels.
* You can check documents of
* {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter}
* {@link org.apache.spark.ml.tuning.TrainValidationSplitModel.TrainValidationSplitModelWriter}
* for more information.
*
* @group expertSetParam
Expand Down Expand Up @@ -275,15 +275,14 @@ class TrainValidationSplitModel private[ml] (

/**
* @return submodels represented in array. The index of array corresponds to the ordering of
* estimatorParamsMaps
*
* Note: If submodels not available, exception will be thrown. only when we set collectSubModels
* Param before fitting, submodels will be available.
* estimatorParamMaps
* @throws IllegalArgumentException if subModels are not available. To retrieve subModels,
* make sure to set collectSubModels to true before fitting.
*/
@Since("2.3.0")
def subModels: Array[Model[_]] = {
require(_subModels.isDefined, "submodels not available, set collectSubModels param before " +
"fitting will address this issue.")
require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " +
"to set collectSubModels to true before fitting.")
_subModels.get
}

Expand Down Expand Up @@ -333,21 +332,26 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
* Writer for TrainValidationSplitModel.
* @param instance TrainValidationSplitModel instance used to construct the writer
*
* Options:
* TrainValidationSplitModel support an option "persistSubModels", available value is
* "true" or "false". If you set collectSubModels param before fitting, and then you can set
* the option "persistSubModels" to be "true" and the submodels will be persisted.
* The default value of "persistSubModels" will be "true", if you set collectSubModels
* param before fitting, but if you do not set collectSubModels param before fitting, setting
* "persistSubModels" will cause exception.
* TrainValidationSplitModel supports an option "persistSubModels", with possible values
* "true" or "false". If you set the collectSubModels Param before fitting, then you can
* set "persistSubModels" to "true" in order to persist the subModels. By default,
* "persistSubModels" will be "true" when subModels are available and "false" otherwise.
* If subModels are not available, then setting "persistSubModels" to "true" will cause
* an exception.
*/
class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter {
final class TrainValidationSplitModelWriter private[tuning] (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since annotation

instance: TrainValidationSplitModel) extends MLWriter {

ValidatorParams.validateParams(instance)

override protected def saveImpl(path: String): Unit = {
val persistSubModels = optionMap.getOrElse("persistsubmodels",
if (instance.hasSubModels) "true" else "false").toBoolean
val persistSubModelsParam = optionMap.getOrElse("persistsubmodels",
if (instance.hasSubModels) "true" else "false")

require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)),
s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " +
"values are \"true\" or \"false\"")
val persistSubModels = persistSubModelsParam.toBoolean

import org.json4s.JsonDSL._
val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toSeq) ~
Expand Down
5 changes: 3 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,13 @@ abstract class MLWriter extends BaseReadWrite with Logging {
protected def saveImpl(path: String): Unit

/**
* Map store extra options for this writer.
* Map to store extra options for this writer.
*/
protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]()

/**
* `option()` handles extra options.
* Adds an option to the underlying MLWriter. See the documentation for the specific model's
* writer for possible options. The option name (key) is case-insensitive.
*/
@Since("2.3.0")
def option(key: String, value: String): this.type = {
Expand Down