diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 2012d6ca8b5ea..158b0d5a63ac5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -73,11 +73,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
 
   /** @group setParam */
   @Since("1.2.0")
-  def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
+  def setEstimator(value: Estimator[_]): this.type = setEstimators(Array(value))
 
   /** @group setParam */
   @Since("1.2.0")
-  def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
+  def setEstimatorParamMaps(value: Array[ParamMap]): this.type =
+    setEstimatorsParamMaps(Array(value))
+
+  /** @group setParam */
+  def setEstimators(value: Array[Estimator[_]]): this.type = set(estimators, value)
+
+  /** @group setParam */
+  def setEstimatorsParamMaps(value: Array[Array[ParamMap]]): this.type =
+    set(estimatorsParamMaps, value)
 
   /** @group setParam */
   @Since("1.2.0")
@@ -96,15 +104,15 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val schema = dataset.schema
     transformSchema(schema, logging = true)
     val sparkSession = dataset.sparkSession
-    val est = $(estimator)
+    val ests = $(estimators)
     val eval = $(evaluator)
-    val epm = $(estimatorParamMaps)
-    val numModels = epm.length
-    val metrics = new Array[Double](epm.length)
+    val epms = $(estimatorsParamMaps).flatten
+    val metrics = new Array[Double](getModelCount)
+    val modelToEstIndex = getModelToEstIndex
 
     val instr = Instrumentation.create(this, dataset)
     instr.logParams(numFolds, seed)
-    logTuningParams(instr)
+    ests.indices.foreach(logTuningParams(instr, _))
 
     val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
     splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
@@ -112,26 +120,30 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
       val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
       // multi-model training
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
-      val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+      val models = ests.zip($(estimatorsParamMaps))
+        .flatMap(estEpm => estEpm._1.fit(trainingDataset, estEpm._2).asInstanceOf[Seq[Model[_]]])
       trainingDataset.unpersist()
       var i = 0
-      while (i < numModels) {
+      while (i < getModelCount) {
         // TODO: duplicate evaluator to take extra params from input
-        val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
-        logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
+        val metric = eval.evaluate(models(i).transform(validationDataset, epms(i)))
+        logDebug(s"Got metric $metric for model trained with " +
+          s"${ests(modelToEstIndex(i))} and parameters ${epms(i)}.")
         metrics(i) += metric
         i += 1
       }
       validationDataset.unpersist()
     }
-    f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
+    f2jBLAS.dscal(getModelCount, 1.0 / $(numFolds), metrics, 1)
     logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
     val (bestMetric, bestIndex) =
       if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
       else metrics.zipWithIndex.minBy(_._1)
-    logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+    logInfo(s"Best estimator:\n${ests(modelToEstIndex(bestIndex))}")
+    logInfo(s"Best set of parameters:\n${epms(bestIndex)}")
     logInfo(s"Best cross-validation metric: $bestMetric.")
-    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+    val bestModel = ests(modelToEstIndex(bestIndex))
+      .fit(dataset, epms(bestIndex)).asInstanceOf[Model[_]]
     instr.logSuccess(bestModel)
     copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
   }
@@ -142,8 +154,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   @Since("1.4.0")
   override def copy(extra: ParamMap): CrossValidator = {
     val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
-    if (copied.isDefined(estimator)) {
-      copied.setEstimator(copied.getEstimator.copy(extra))
+    if (copied.isDefined(estimators)) {
+      copied.setEstimators(copied.getEstimators.map(_.copy(extra)))
     }
     if (copied.isDefined(evaluator)) {
       copied.setEvaluator(copied.getEvaluator.copy(extra))
@@ -183,14 +195,14 @@ object CrossValidator extends MLReadable[CrossValidator] {
 
     override def load(path: String): CrossValidator = {
       implicit val format = DefaultFormats
-      val (metadata, estimator, evaluator, estimatorParamMaps) =
+      val (metadata, estimators, evaluator, estimatorsParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
       val numFolds = (metadata.params \ "numFolds").extract[Int]
       val seed = (metadata.params \ "seed").extract[Long]
       new CrossValidator(metadata.uid)
-        .setEstimator(estimator)
+        .setEstimators(estimators)
         .setEvaluator(evaluator)
-        .setEstimatorParamMaps(estimatorParamMaps)
+        .setEstimatorsParamMaps(estimatorsParamMaps)
         .setNumFolds(numFolds)
         .setSeed(seed)
     }
@@ -273,7 +285,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
 
     override def load(path: String): CrossValidatorModel = {
       implicit val format = DefaultFormats
-      val (metadata, estimator, evaluator, estimatorParamMaps) =
+      val (metadata, estimators, evaluator, estimatorsParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
       val numFolds = (metadata.params \ "numFolds").extract[Int]
       val seed = (metadata.params \ "seed").extract[Long]
@@ -281,9 +293,9 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
       val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
       val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
-      model.set(model.estimator, estimator)
+      model.set(model.estimators, estimators)
         .set(model.evaluator, evaluator)
-        .set(model.estimatorParamMaps, estimatorParamMaps)
+        .set(model.estimatorsParamMaps, estimatorsParamMaps)
         .set(model.numFolds, numFolds)
         .set(model.seed, seed)
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index db7c9d13d301a..a7ba688732e3c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -69,16 +69,24 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
 
   /** @group setParam */
   @Since("1.5.0")
-  def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
+  def setEstimator(value: Estimator[_]): this.type = setEstimators(Array(value))
 
   /** @group setParam */
   @Since("1.5.0")
-  def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
+  def setEstimatorParamMaps(value: Array[ParamMap]): this.type =
+    setEstimatorsParamMaps(Array(value))
 
   /** @group setParam */
   @Since("1.5.0")
   def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
 
+  /** @group setParam */
+  def setEstimators(value: Array[Estimator[_]]): this.type = set(estimators, value)
+
+  /** @group setParam */
+  def setEstimatorsParamMaps(value: Array[Array[ParamMap]]): this.type =
+    set(estimatorsParamMaps, value)
+
   /** @group setParam */
   @Since("1.5.0")
   def setTrainRatio(value: Double): this.type = set(trainRatio, value)
@@ -91,15 +99,15 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
     val schema = dataset.schema
     transformSchema(schema, logging = true)
-    val est = $(estimator)
+    val ests = $(estimators)
     val eval = $(evaluator)
-    val epm = $(estimatorParamMaps)
-    val numModels = epm.length
-    val metrics = new Array[Double](epm.length)
+    val epms = $(estimatorsParamMaps).flatten
+    val metrics = new Array[Double](getModelCount)
+    val modelToEstIndex = getModelToEstIndex
 
     val instr = Instrumentation.create(this, dataset)
     instr.logParams(trainRatio, seed)
-    logTuningParams(instr)
+    ests.indices.foreach(logTuningParams(instr, _))
 
     val Array(trainingDataset, validationDataset) =
       dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
@@ -108,13 +116,15 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
 
     // multi-model training
     logDebug(s"Train split with multiple sets of parameters.")
-    val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+    val models = ests.zip($(estimatorsParamMaps))
+      .flatMap(estEpm => estEpm._1.fit(trainingDataset, estEpm._2).asInstanceOf[Seq[Model[_]]])
     trainingDataset.unpersist()
     var i = 0
-    while (i < numModels) {
+    while (i < getModelCount) {
       // TODO: duplicate evaluator to take extra params from input
-      val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
-      logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
+      val metric = eval.evaluate(models(i).transform(validationDataset, epms(i)))
+      logDebug(s"Got metric $metric for model trained with " +
+        s"${ests(modelToEstIndex(i))} and parameters ${epms(i)}.")
       metrics(i) += metric
       i += 1
     }
@@ -124,9 +134,11 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     val (bestMetric, bestIndex) =
       if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
       else metrics.zipWithIndex.minBy(_._1)
-    logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+    logInfo(s"Best estimator:\n${ests(modelToEstIndex(bestIndex))}")
+    logInfo(s"Best set of parameters:\n${epms(bestIndex)}")
     logInfo(s"Best train validation split metric: $bestMetric.")
-    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+    val bestModel = ests(modelToEstIndex(bestIndex))
+      .fit(dataset, epms(bestIndex)).asInstanceOf[Model[_]]
     instr.logSuccess(bestModel)
     copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
   }
@@ -137,8 +149,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   @Since("1.5.0")
   override def copy(extra: ParamMap): TrainValidationSplit = {
     val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit]
-    if (copied.isDefined(estimator)) {
-      copied.setEstimator(copied.getEstimator.copy(extra))
+    if (copied.isDefined(estimators)) {
+      copied.setEstimators(copied.getEstimators.map(_.copy(extra)))
     }
     if (copied.isDefined(evaluator)) {
       copied.setEvaluator(copied.getEvaluator.copy(extra))
@@ -176,14 +188,14 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
 
     override def load(path: String): TrainValidationSplit = {
       implicit val format = DefaultFormats
-      val (metadata, estimator, evaluator, estimatorParamMaps) =
+      val (metadata, estimators, evaluator, estimatorsParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
       val trainRatio = (metadata.params \ "trainRatio").extract[Double]
       val seed = (metadata.params \ "seed").extract[Long]
       new TrainValidationSplit(metadata.uid)
-        .setEstimator(estimator)
+        .setEstimators(estimators)
         .setEvaluator(evaluator)
-        .setEstimatorParamMaps(estimatorParamMaps)
+        .setEstimatorsParamMaps(estimatorsParamMaps)
         .setTrainRatio(trainRatio)
         .setSeed(seed)
     }
@@ -264,7 +276,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
 
     override def load(path: String): TrainValidationSplitModel = {
       implicit val format = DefaultFormats
-      val (metadata, estimator, evaluator, estimatorParamMaps) =
+      val (metadata, estimators, evaluator, estimatorsParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
       val trainRatio = (metadata.params \ "trainRatio").extract[Double]
       val seed = (metadata.params \ "seed").extract[Long]
@@ -272,9 +284,9 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
       val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
       val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
-      model.set(model.estimator, estimator)
+      model.set(model.estimators, estimators)
         .set(model.evaluator, evaluator)
-        .set(model.estimatorParamMaps, estimatorParamMaps)
+        .set(model.estimatorsParamMaps, estimatorsParamMaps)
         .set(model.trainRatio, trainRatio)
         .set(model.seed, seed)
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index d55eb14d03456..88c5aa760c772 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -40,21 +40,28 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
    *
    * @group param
    */
-  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
+  val estimators: Param[Array[Estimator[_]]] =
+    new Param(this, "estimators", "estimators for selection")
 
   /** @group getParam */
-  def getEstimator: Estimator[_] = $(estimator)
+  def getEstimator: Estimator[_] = $(estimators).head
+
+  /** @group getParam */
+  def getEstimators: Array[Estimator[_]] = $(estimators)
 
   /**
    * param for estimator param maps
    *
    * @group param
    */
-  val estimatorParamMaps: Param[Array[ParamMap]] =
-    new Param(this, "estimatorParamMaps", "param maps for the estimator")
+  val estimatorsParamMaps: Param[Array[Array[ParamMap]]] =
+    new Param(this, "estimatorsParamMaps", "param maps for the estimators")
+
+  /** @group getParam */
+  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorsParamMaps).head
 
   /** @group getParam */
-  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
+  def getEstimatorsParamMaps: Array[Array[ParamMap]] = $(estimatorsParamMaps)
 
   /**
    * param for the evaluator used to select hyper-parameters that maximize the validated metric
   *
   * @group param
   */
@@ -67,30 +74,49 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
 
+  protected def getModelToEstIndex =
+    $(estimatorsParamMaps).map(_.length).zipWithIndex.flatMap(idx => List.fill(idx._1)(idx._2))
+
+  protected def getModelCount: Int = $(estimatorsParamMaps).flatten.length
+
   protected def transformSchemaImpl(schema: StructType): StructType = {
-    require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
-    val firstEstimatorParamMap = $(estimatorParamMaps).head
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps).tail) {
-      est.copy(paramMap).transformSchema(schema)
+    def transformSchemaIdx(schema: StructType, idx: Int): StructType = {
+      require($(estimatorsParamMaps)(idx).nonEmpty,
+        s"Validator requires non-empty estimatorParamMaps")
+      val firstEstimatorParamMap = $(estimatorsParamMaps)(idx).head
+      val est = $(estimators)(idx)
+      for (paramMap <- $(estimatorsParamMaps)(idx).tail) {
+        est.copy(paramMap).transformSchema(schema)
+      }
+      est.copy(firstEstimatorParamMap).transformSchema(schema)
     }
-    est.copy(firstEstimatorParamMap).transformSchema(schema)
+
+    require($(estimators).length == $(estimatorsParamMaps).length,
+      s"Number of estimators must match number of estimatorParamMaps")
+    $(estimators).indices.map(idx => transformSchemaIdx(schema, idx))
+      .foldLeft(schema)((acc, curr) => acc.merge(curr))
   }
 
   /**
    * Instrumentation logging for tuning params including the inner estimator and evaluator info.
    */
-  protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = {
-    instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName)
+  protected def logTuningParams(instrumentation: Instrumentation[_], idx: Int = 0): Unit = {
+    instrumentation.logNamedValue("estimator", $(estimators).length match {
+      case len if len > idx => $(estimators)(idx).getClass.getCanonicalName
+      case _ => $(estimators).getClass.getCanonicalName
+    })
     instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
-    instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
+    instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorsParamMaps).length match {
+      case len if len > idx => $(estimatorsParamMaps)(idx).length
+      case _ => 0
+    })
   }
 }
 
 private[ml] object ValidatorParams {
   /**
-   * Check that [[ValidatorParams.evaluator]] and [[ValidatorParams.estimator]] are Writable.
-   * This does not check [[ValidatorParams.estimatorParamMaps]].
+   * Check that [[ValidatorParams.evaluator]] and [[ValidatorParams.estimators]] are Writable.
+   * This does not check [[ValidatorParams.estimatorsParamMaps]].
    */
   def validateParams(instance: ValidatorParams): Unit = {
     def checkElement(elem: Params, name: String): Unit = elem match {
@@ -101,15 +127,17 @@ private[ml] object ValidatorParams {
         s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
     }
     checkElement(instance.getEvaluator, "evaluator")
-    checkElement(instance.getEstimator, "estimator")
+    instance.getEstimators.foreach(checkElement(_, "estimator"))
     // Check to make sure all Params apply to this estimator. Throw an error if any do not.
     // Extraneous Params would cause problems when loading the estimatorParamMaps.
     val uidToInstance: Map[String, Params] = MetaAlgorithmReadWrite.getUidMap(instance)
-    instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
-      pMap.toSeq.foreach { case ParamPair(p, v) =>
-        require(uidToInstance.contains(p.parent), s"ValidatorParams save requires all Params in" +
-          s" estimatorParamMaps to apply to this ValidatorParams, its Estimator, or its" +
-          s" Evaluator. An extraneous Param was found: $p")
+    instance.getEstimatorsParamMaps.foreach {
+      _.foreach { case pMap: ParamMap =>
+        pMap.toSeq.foreach { case ParamPair(p, v) =>
+          require(uidToInstance.contains(p.parent), s"ValidatorParams save requires all Params in" +
+            s" estimatorParamMaps to apply to this ValidatorParams, its Estimator, or its" +
+            s" Evaluator. An extraneous Param was found: $p")
+        }
       }
     }
   }
@@ -126,13 +154,17 @@ private[ml] object ValidatorParams {
       extraMetadata: Option[JObject] = None): Unit = {
     import org.json4s.JsonDSL._
 
-    val estimatorParamMapsJson = compact(render(
-      instance.getEstimatorParamMaps.map { case paramMap =>
-        paramMap.toSeq.map { case ParamPair(p, v) =>
-          Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
-        }
-      }.toSeq
-    ))
+    val estimatorsParamMapsJson =
+      instance.getEstimatorsParamMaps.zipWithIndex.map { case (paramMaps, idx) =>
+        val json = compact(render(
+          paramMaps.map { case paramMap =>
+            paramMap.toSeq.map { case ParamPair(p, v) =>
+              Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+            }
+          }.toSeq
+        ))
+        indexKeyName("estimatorParamMaps", idx) -> parse(json)
+      }
 
     val validatorSpecificParams = instance match {
       case cv: CrossValidatorParams =>
@@ -145,16 +177,19 @@ private[ml] object ValidatorParams {
           instance.getClass.getCanonicalName)
     }
 
-    val jsonParams = validatorSpecificParams ++ List(
-      "estimatorParamMaps" -> parse(estimatorParamMapsJson),
+    val jsonParams = validatorSpecificParams ++ estimatorsParamMapsJson ++ List(
+      "estimatorNumber" -> parse(instance.getEstimators.length.toString),
       "seed" -> parse(instance.seed.jsonEncode(instance.getSeed)))
 
     DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
 
     val evaluatorPath = new Path(path, "evaluator").toString
     instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
-    val estimatorPath = new Path(path, "estimator").toString
-    instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+
+    instance.getEstimators.zipWithIndex.foreach { case(est, idx) =>
+      val estimatorPath = new Path(path, indexKeyName("estimator", idx)).toString
+      est.asInstanceOf[MLWritable].save(estimatorPath)
+    }
   }
 
   /**
@@ -165,21 +200,28 @@ private[ml] object ValidatorParams {
   def loadImpl[M <: Model[M]](
       path: String,
       sc: SparkContext,
-      expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap]) = {
+      expectedClassName: String):
+    (Metadata, Array[Estimator[_]], Evaluator, Array[Array[ParamMap]]) = {
     val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
     implicit val format = DefaultFormats
     val evaluatorPath = new Path(path, "evaluator").toString
     val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
-    val estimatorPath = new Path(path, "estimator").toString
-    val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
-    val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator)
+    // backwards compatible if missing, assumes 1 estimator
+    val estimatorNumber = (metadata.params \ "estimatorNumber").extractOrElse[String]("1").toInt
+    val estimators: Array[Estimator[_]] = (0 until estimatorNumber).map { idx =>
+      val estimatorPath = new Path(path, indexKeyName("estimator", idx)).toString
+      DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+    }.toArray
+
+    val uidToParams = Map(evaluator.uid -> evaluator) ++
+      estimators.flatMap(MetaAlgorithmReadWrite.getUidMap)
 
-    val estimatorParamMaps: Array[ParamMap] =
-      (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
-        pMap =>
+    val estimatorsParamsMaps: Array[Array[ParamMap]] = (0 until estimatorNumber).map { idx =>
+      (metadata.params \ indexKeyName("estimatorParamMaps", idx))
+        .extract[Seq[Seq[Map[String, String]]]].map { pMap =>
           val paramPairs = pMap.map { case pInfo: Map[String, String] =>
             val est = uidToParams(pInfo("parent"))
             val param = est.getParam(pInfo("name"))
@@ -188,7 +230,13 @@ private[ml] object ValidatorParams {
           }
           ParamMap(paramPairs: _*)
         }.toArray
+    }.toArray
+
+    (metadata, estimators, evaluator, estimatorsParamsMaps)
+  }
 
-    (metadata, estimator, evaluator, estimatorParamMaps)
+  private def indexKeyName(name: String, idx: Int): String = idx match {
+    case idx if idx == 0 => name // backwards compatible name for head
+    case idx => name + idx.toString
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 09bddcdb810bb..ed2436b8e6c7a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -458,7 +458,7 @@ private[ml] object MetaAlgorithmReadWrite {
     val subStages: Array[Params] = instance match {
       case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
       case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
-      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+      case v: ValidatorParams => v.getEstimators ++ Array(v.getEvaluator)
       case ovr: OneVsRest => Array(ovr.getClassifier)
       case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
       case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 7116265474f22..f0de41bdbafcb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,11 +18,11 @@
 package org.apache.spark.ml.tuning
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.{Estimator, Model, Pipeline}
-import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel}
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, NaiveBayes}
 import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.feature.{Binarizer, HashingTF, Tokenizer}
 import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamPair}
 import org.apache.spark.ml.param.shared.HasInputCol
@@ -254,10 +254,10 @@ class CrossValidatorSuite
       .addGrid(lr.regParam, Array(0.1, 0.2))
       .build()
     val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6))
-    cv.set(cv.estimator, lr)
+    cv.set(cv.estimators, Array[org.apache.spark.ml.Estimator[_]](lr))
       .set(cv.evaluator, evaluator)
       .set(cv.numFolds, 20)
-      .set(cv.estimatorParamMaps, paramMaps)
+      .set(cv.estimatorsParamMaps, Array[Array[org.apache.spark.ml.param.ParamMap]](paramMaps))
 
     val cv2 = testDefaultReadWrite(cv, testParams = false)
@@ -293,6 +293,70 @@ class CrossValidatorSuite
     }
     assert(cv.avgMetrics === cv2.avgMetrics)
   }
+
+  test("cross validation with multiple pipelines: logistic regression and naive bayes") {
+    val training = spark.createDataFrame(Seq(
+      (0L, "a b c d e spark", 1.0),
+      (1L, "b d", 0.0),
+      (2L, "spark f g h", 1.0),
+      (3L, "hadoop mapreduce", 0.0),
+      (4L, "b spark who", 1.0),
+      (5L, "g d a y", 0.0),
+      (6L, "spark fly", 1.0),
+      (7L, "was mapreduce", 0.0),
+      (8L, "e spark program", 1.0),
+      (9L, "a e c l", 0.0),
+      (10L, "spark compile", 1.0),
+      (11L, "hadoop software", 0.0)
+    )).toDF("id", "text", "label")
+
+    val tokenizer = new Tokenizer()
+      .setInputCol("text")
+      .setOutputCol("words")
+    val hashingTF = new HashingTF()
+      .setInputCol(tokenizer.getOutputCol)
+      .setOutputCol("features")
+
+    // Configure an ML pipeline using multinomial naive Bayes.
+    val nb = new NaiveBayes()
+    val pipeline1 = new Pipeline("p1").setStages(Array(tokenizer, hashingTF, nb))
+    val paramGrid1 = new ParamGridBuilder()
+      .addGrid(hashingTF.numFeatures, Array(10, 100))
+      .build()
+
+    // Configure an ML pipeline using logistic regression.
+    val lr = new LogisticRegression().setMaxIter(10)
+    val pipeline2 = new Pipeline("p2").setStages(Array(tokenizer, hashingTF, lr))
+    val paramGrid2 = new ParamGridBuilder()
+      .addGrid(hashingTF.numFeatures, Array(10, 100))
+      .build()
+
+    // Configure an ML pipeline using Bernoulli naive Bayes (4 stages).
+    val binarizer = new Binarizer()
+      .setInputCol(hashingTF.getOutputCol)
+      .setOutputCol("binary_features")
+    val nb2 = new NaiveBayes()
+      .setModelType("bernoulli")
+      .setFeaturesCol(binarizer.getOutputCol)
+    val pipeline3 = new Pipeline("p3").setStages(Array(tokenizer, hashingTF, binarizer, nb2))
+    val paramGrid3 = new ParamGridBuilder()
+      .addGrid(hashingTF.numFeatures, Array(10, 100))
+      .build()
+
+    // Cross-validate over all three pipelines, each with its own parameter grid.
+    val cv = new CrossValidator()
+      .setEstimators(Array(pipeline1, pipeline2, pipeline3))
+      .setEvaluator(new BinaryClassificationEvaluator)
+      .setEstimatorsParamMaps(Array(paramGrid1, paramGrid2, paramGrid3))
+      .setNumFolds(2)
+
+    // Run cross-validation, and choose the best set of parameters.
+    val cvModel = cv.fit(training)
+
+    assert(cvModel.bestModel.uid === "p2")
+    assert(cvModel.bestModel.asInstanceOf[PipelineModel].stages(1)
+      .asInstanceOf[HashingTF].getNumFeatures === 100)
+  }
 }
 
 object CrossValidatorSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 4463a9b6e543a..82b6d1e6ed988 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -147,10 +147,10 @@ class TrainValidationSplitSuite
       .addGrid(lr.regParam, Array(0.1, 0.2))
       .build()
     val tvs = new TrainValidationSplitModel("cvUid", lrModel, Array(0.3, 0.6))
-    tvs.set(tvs.estimator, lr)
+    tvs.set(tvs.estimators, Array[org.apache.spark.ml.Estimator[_]](lr))
      .set(tvs.evaluator, evaluator)
      .set(tvs.trainRatio, 0.5)
-      .set(tvs.estimatorParamMaps, paramMaps)
+      .set(tvs.estimatorsParamMaps, Array[Array[org.apache.spark.ml.param.ParamMap]](paramMaps))
      .set(tvs.seed, 42L)
 
     val tvs2 = testDefaultReadWrite(tvs, testParams = false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 8d8b5b86d5aa1..0b83362e9e00b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -375,7 +375,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
    * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be
    * thrown.
    */
-  private[sql] def merge(that: StructType): StructType =
+  private[spark] def merge(that: StructType): StructType =
     StructType.merge(this, that).asInstanceOf[StructType]
 
   override private[spark] def asNullable: StructType = {
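
A minimal usage sketch of the multi-estimator API introduced by this patch (not part of the diff above): the `training` DataFrame with "label" and "features" columns, the choice of estimators, and the grid values are all assumed for illustration.

import org.apache.spark.ml.classification.{LogisticRegression, NaiveBayes}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Two unrelated estimators competing in a single cross-validation run.
val lr = new LogisticRegression().setMaxIter(10)
val nb = new NaiveBayes()  // note: naive Bayes expects non-negative feature values

// One parameter grid per estimator, in the same order as the estimators array.
val lrGrid = new ParamGridBuilder().addGrid(lr.regParam, Array(0.01, 0.1)).build()
val nbGrid = new ParamGridBuilder().addGrid(nb.smoothing, Array(0.5, 1.0)).build()

val cv = new CrossValidator()
  .setEstimators(Array(lr, nb))
  .setEstimatorsParamMaps(Array(lrGrid, nbGrid))
  .setEvaluator(new BinaryClassificationEvaluator)
  .setNumFolds(3)

// bestModel is the single best model across all estimators and all parameter maps;
// `training` is a hypothetical DataFrame with "label" and "features" columns.
val cvModel = cv.fit(training)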