-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-22060][ML] Fix CrossValidator/TrainValidationSplit param persist/load bug #19278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
|
|
||
| package org.apache.spark.ml.tuning | ||
|
|
||
| import java.io.IOException | ||
| import java.util.{List => JList} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
|
|
@@ -212,14 +213,12 @@ object CrossValidator extends MLReadable[CrossValidator] { | |
|
|
||
| val (metadata, estimator, evaluator, estimatorParamMaps) = | ||
| ValidatorParams.loadImpl(path, sc, className) | ||
| val numFolds = (metadata.params \ "numFolds").extract[Int] | ||
| val seed = (metadata.params \ "seed").extract[Long] | ||
| new CrossValidator(metadata.uid) | ||
| val cv = new CrossValidator(metadata.uid) | ||
| .setEstimator(estimator) | ||
| .setEvaluator(evaluator) | ||
| .setEstimatorParamMaps(estimatorParamMaps) | ||
| .setNumFolds(numFolds) | ||
| .setSeed(seed) | ||
| DefaultParamsReader.getAndSetParams(cv, metadata, skipParams = List("estimatorParamMaps")) | ||
|
||
| cv | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -303,16 +302,16 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | |
| val (metadata, estimator, evaluator, estimatorParamMaps) = | ||
| ValidatorParams.loadImpl(path, sc, className) | ||
| val numFolds = (metadata.params \ "numFolds").extract[Int] | ||
|
||
| val seed = (metadata.params \ "seed").extract[Long] | ||
| val bestModelPath = new Path(path, "bestModel").toString | ||
| val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) | ||
| val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray | ||
|
|
||
| val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) | ||
| model.set(model.estimator, estimator) | ||
| .set(model.evaluator, evaluator) | ||
| .set(model.estimatorParamMaps, estimatorParamMaps) | ||
| .set(model.numFolds, numFolds) | ||
| .set(model.seed, seed) | ||
| DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps")) | ||
|
||
| model | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -399,14 +399,17 @@ private[ml] object DefaultParamsReader { | |
| * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. | ||
|
||
| * TODO: Move to [[Metadata]] method | ||
| */ | ||
| def getAndSetParams(instance: Params, metadata: Metadata): Unit = { | ||
| def getAndSetParams(instance: Params, metadata: Metadata, | ||
|
||
| skipParams: List[String] = null): Unit = { | ||
|
||
| implicit val format = DefaultFormats | ||
| metadata.params match { | ||
| case JObject(pairs) => | ||
| pairs.foreach { case (paramName, jsonValue) => | ||
| val param = instance.getParam(paramName) | ||
| val value = param.jsonDecode(compact(render(jsonValue))) | ||
| instance.set(param, value) | ||
| if (skipParams == null || !skipParams.contains(paramName)) { | ||
| val param = instance.getParam(paramName) | ||
| val value = param.jsonDecode(compact(render(jsonValue))) | ||
| instance.set(param, value) | ||
| } | ||
| } | ||
| case _ => | ||
| throw new IllegalArgumentException( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -159,12 +159,15 @@ class CrossValidatorSuite | |
| .setEvaluator(evaluator) | ||
| .setNumFolds(20) | ||
| .setEstimatorParamMaps(paramMaps) | ||
| .setSeed(42L) | ||
| .setParallelism(2) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please update the test for the model too.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ditto. |
||
|
|
||
| val cv2 = testDefaultReadWrite(cv, testParams = false) | ||
|
|
||
| assert(cv.uid === cv2.uid) | ||
| assert(cv.getNumFolds === cv2.getNumFolds) | ||
| assert(cv.getSeed === cv2.getSeed) | ||
| assert(cv.getParallelism === cv2.getParallelism) | ||
|
|
||
| assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) | ||
| val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio | |
| import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput | ||
| import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} | ||
| import org.apache.spark.ml.linalg.Vectors | ||
| import org.apache.spark.ml.param.{ParamMap} | ||
| import org.apache.spark.ml.param.ParamMap | ||
| import org.apache.spark.ml.param.shared.HasInputCol | ||
| import org.apache.spark.ml.regression.LinearRegression | ||
| import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} | ||
|
|
@@ -160,11 +160,13 @@ class TrainValidationSplitSuite | |
| .setTrainRatio(0.5) | ||
| .setEstimatorParamMaps(paramMaps) | ||
| .setSeed(42L) | ||
| .setParallelism(2) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you update the test for the Model too, please?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No. The model does not own the `parallelism` param.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, you're right, thanks. |
||
|
|
||
| val tvs2 = testDefaultReadWrite(tvs, testParams = false) | ||
|
|
||
| assert(tvs.getTrainRatio === tvs2.getTrainRatio) | ||
| assert(tvs.getSeed === tvs2.getSeed) | ||
| assert(tvs.getParallelism === tvs2.getParallelism) | ||
|
|
||
| ValidatorParamsSuiteHelpers | ||
| .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This exception import (`java.io.IOException`) is unused and can be removed.