Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
init pr
  • Loading branch information
WeichenXu123 committed Sep 20, 2017
commit 86abf481949d61cbcc726bcadfa91d12686846d6
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.ml.tuning

import java.io.IOException
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This exception is unused & can be removed.

import java.util.{List => JList}

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -212,14 +213,12 @@ object CrossValidator extends MLReadable[CrossValidator] {

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val seed = (metadata.params \ "seed").extract[Long]
new CrossValidator(metadata.uid)
val cv = new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
.setNumFolds(numFolds)
.setSeed(seed)
DefaultParamsReader.getAndSetParams(cv, metadata, skipParams = List("estimatorParamMaps"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you also need to skip estimator and evaluator

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. Because estimator and evaluator isn't included in metadata. You can check the saveImpl.

cv
}
}
}
Expand Down Expand Up @@ -303,16 +302,16 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numFolds is no longer needed

val seed = (metadata.params \ "seed").extract[Long]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray

val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
model.set(model.estimator, estimator)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
.set(model.numFolds, numFolds)
.set(model.seed, seed)
DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also skip estimator and evaluator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. Because estimator and evaluator isn't included in metadata. You can check the saveImpl.

model
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.ml.tuning

import java.io.IOException
import java.util.{List => JList}

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -207,14 +208,12 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
val seed = (metadata.params \ "seed").extract[Long]
new TrainValidationSplit(metadata.uid)
val tvs = new TrainValidationSplit(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
.setTrainRatio(trainRatio)
.setSeed(seed)
DefaultParamsReader.getAndSetParams(tvs, metadata, skipParams = List("estimatorParamMaps"))
tvs
}
}
}
Expand Down Expand Up @@ -295,17 +294,16 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
val seed = (metadata.params \ "seed").extract[Long]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray

val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
model.set(model.estimator, estimator)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
.set(model.trainRatio, trainRatio)
.set(model.seed, seed)
DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps"))
model
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,20 +150,14 @@ private[ml] object ValidatorParams {
}.toSeq
))

val validatorSpecificParams = instance match {
case cv: CrossValidatorParams =>
List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
case tvs: TrainValidationSplitParams =>
List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
case _ =>
// This should not happen.
throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
instance.getClass.getCanonicalName)
}

val jsonParams = validatorSpecificParams ++ List(
"estimatorParamMaps" -> parse(estimatorParamMapsJson),
"seed" -> parse(instance.seed.jsonEncode(instance.getSeed)))
val params = instance.extractParamMap().toSeq
val skipParams = List("estimator", "evaluator", "estimatorParamMaps")
val jsonParams = render(params
.filter { case ParamPair(p, v) => !skipParams.contains(p.name)}
.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
)

DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))

Expand Down
11 changes: 7 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -399,14 +399,17 @@ private[ml] object DefaultParamsReader {
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the docstring to state that params included in skipParams aren't set.

* TODO: Move to [[Metadata]] method
*/
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
def getAndSetParams(instance: Params, metadata: Metadata,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix scala style: 1 arg per line for multiline declarations

skipParams: List[String] = null): Unit = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use an Option[List[String]] that defaults to None instead of a List[String] that defaults to null?

implicit val format = DefaultFormats
metadata.params match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, value)
if (skipParams == null || !skipParams.contains(paramName)) {
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, value)
}
}
case _ =>
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,15 @@ class CrossValidatorSuite
.setEvaluator(evaluator)
.setNumFolds(20)
.setEstimatorParamMaps(paramMaps)
.setSeed(42L)
.setParallelism(2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the test for the model too please

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.


val cv2 = testDefaultReadWrite(cv, testParams = false)

assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
assert(cv.getSeed === cv2.getSeed)
assert(cv.getParallelism === cv2.getParallelism)

assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{ParamMap}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
Expand Down Expand Up @@ -160,11 +160,13 @@ class TrainValidationSplitSuite
.setTrainRatio(0.5)
.setEstimatorParamMaps(paramMaps)
.setSeed(42L)
.setParallelism(2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you update the test for the Model too please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. The model do not own parallel parameter. This was discussed before.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, you're right, thanks


val tvs2 = testDefaultReadWrite(tvs, testParams = false)

assert(tvs.getTrainRatio === tvs2.getTrainRatio)
assert(tvs.getSeed === tvs2.getSeed)
assert(tvs.getParallelism === tvs2.getParallelism)

ValidatorParamsSuiteHelpers
.compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
Expand Down