Skip to content
Closed
Prev Previous commit
address minor issues
  • Loading branch information
WeichenXu123 committed Nov 14, 2017
commit 7e997da44157a9807de9c8fe8e7d2e5b66b6bfb1
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,10 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
val shouldPersistSubModels = (metadata.metadata \ "persistSubModels").extract[Boolean]
val persistSubModels = (metadata.metadata \ "persistSubModels")
.extractOrElse[Boolean](false)

val subModels: Option[Array[Array[Model[_]]]] = if (shouldPersistSubModels) {
val subModels: Option[Array[Array[Model[_]]]] = if (persistSubModels) {
val subModelsPath = new Path(path, "subModels")
val _subModels = Array.fill(numFolds)(Array.fill[Model[_]](
estimatorParamMaps.length)(null))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
}

@Since("2.0.0")
override def write: TrainValidationSplit.TrainValidationSplitWriter = {
new TrainValidationSplit.TrainValidationSplitWriter(this)
}
override def write: MLWriter = new TrainValidationSplit.TrainValidationSplitWriter(this)
}

@Since("2.0.0")
Expand Down Expand Up @@ -311,7 +309,9 @@ class TrainValidationSplitModel private[ml] (
}

@Since("2.0.0")
override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
override def write: TrainValidationSplitModel.TrainValidationSplitModelWriter = {
new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
}
}

@Since("2.0.0")
Expand Down Expand Up @@ -339,6 +339,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
* If subModels are not available, then setting "persistSubModels" to "true" will cause
* an exception.
*/
@Since("2.3.0")
final class TrainValidationSplitModelWriter private[tuning] (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since annotation

instance: TrainValidationSplitModel) extends MLWriter {

Expand Down Expand Up @@ -385,9 +386,10 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
val shouldPersistSubModels = (metadata.metadata \ "persistSubModels").extract[Boolean]
val persistSubModels = (metadata.metadata \ "persistSubModels")
.extractOrElse[Boolean](false)

val subModels: Option[Array[Model[_]]] = if (shouldPersistSubModels) {
val subModels: Option[Array[Model[_]]] = if (persistSubModels) {
val subModelsPath = new Path(path, "subModels")
val _subModels = Array.fill[Model[_]](estimatorParamMaps.length)(null)
for (paramIndex <- 0 until estimatorParamMaps.length) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ class CrossValidatorSuite
val eval = new BinaryClassificationEvaluator
val numFolds = 3
val subPath = new File(tempDir, "testCrossValidatorSubModels")
val persistSubModelsPath = new File(subPath, "subModels").toString

val cv = new CrossValidator()
.setEstimator(lr)
Expand Down