diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 08b0cb9b8f6a..59ddf1702fe2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -24,9 +26,10 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.storage.StorageLevel
 
 /**
  * (private[ml]) Trait for parameters for prediction (regression and classification).
@@ -99,23 +102,43 @@ abstract class Predictor[
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)
 
+    val cols = ArrayBuffer[Column]()
+    cols.append(col($(featuresCol)))
+
     // Cast LabelCol to DoubleType and keep the metadata.
     val labelMeta = dataset.schema($(labelCol)).metadata
-    val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+    cols.append(col($(labelCol)).cast(DoubleType).as($(labelCol), labelMeta))
 
     // Cast WeightCol to DoubleType and keep the metadata.
-    val casted = this match {
-      case p: HasWeightCol =>
-        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
-          val weightMeta = dataset.schema($(p.weightCol)).metadata
-          labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta)
-        } else {
-          labelCasted
+    this match {
+      case p: HasWeightCol if isDefined(p.weightCol) && $(p.weightCol).nonEmpty =>
+        val weightMeta = dataset.schema($(p.weightCol)).metadata
+        cols.append(col($(p.weightCol)).cast(DoubleType).as($(p.weightCol), weightMeta))
+      case _ =>
+    }
+
+    val selected = dataset.select(cols: _*)
+
+    this match {
+      case p: HasHandlePersistence =>
+        if (dataset.storageLevel == StorageLevel.NONE) {
+          if ($(p.handlePersistence)) {
+            selected.persist(StorageLevel.MEMORY_AND_DISK)
+          } else {
+            logWarning("The input dataset is uncached, which may hurt performance if its " +
+              "upstreams are also uncached.")
+          }
         }
-      case _ => labelCasted
+      case _ =>
+    }
+
+    val model = copyValues(train(selected).setParent(this))
+
+    if (selected.storageLevel != StorageLevel.NONE) {
+      selected.unpersist(blocking = false)
     }
 
-    copyValues(train(casted).setParent(this))
+    model
   }
 
   override def copy(extra: ParamMap): Learner
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index cbc8f4a2d8c2..1df1fecb6f6e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -51,7 +51,8 @@ import org.apache.spark.util.VersionUtils
  */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
     with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
-    with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth {
+    with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth
+    with HasHandlePersistence {
 
   import org.apache.spark.ml.classification.LogisticRegression.supportedFamilyNames
 
@@ -431,6 +432,13 @@ class LogisticRegression @Since("1.2.0") (
   @Since("2.2.0")
   def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)
 
+  /**
+   * Sets whether to handle data persistence.
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   private def assertBoundConstrainedOptimizationParamsValid(
       numCoefficientSets: Int,
       numFeatures: Int): Unit = {
@@ -484,13 +492,6 @@ class LogisticRegression @Since("1.2.0") (
   }
 
   override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    train(dataset, handlePersistence)
-  }
-
-  protected[spark] def train(
-      dataset: Dataset[_],
-      handlePersistence: Boolean): LogisticRegressionModel = {
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
       dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
@@ -498,9 +499,7 @@ class LogisticRegression @Since("1.2.0") (
           Instance(label, weight, features)
       }
 
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
-
-    val instr = Instrumentation.create(this, instances)
+    val instr = Instrumentation.create(this, dataset)
     instr.logParams(regParam, elasticNetParam, standardization, threshold,
       maxIter, tol, fitIntercept)
 
@@ -878,8 +877,6 @@ class LogisticRegression @Since("1.2.0") (
       }
     }
 
-    if (handlePersistence) instances.unpersist()
-
     val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
       numClasses, isMultinomial))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 92a7742f6c86..6667e96d5ea3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -34,7 +34,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
-import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
+import org.apache.spark.ml.param.shared.{HasHandlePersistence, HasParallelism, HasWeightCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -56,7 +56,7 @@ private[ml] trait ClassifierTypeTrait {
  * Params for [[OneVsRest]].
  */
 private[ml] trait OneVsRestParams extends PredictorParams
-  with ClassifierTypeTrait with HasWeightCol {
+  with ClassifierTypeTrait with HasWeightCol with HasHandlePersistence {
 
   /**
    * param for the base binary classifier that we reduce multiclass classification into.
@@ -68,6 +68,13 @@ private[ml] trait OneVsRestParams extends PredictorParams
 
   /** @group getParam */
   def getClassifier: ClassifierType = $(classifier)
+
+  /**
+   * Sets whether to handle data persistence.
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
 }
 
 private[ml] object OneVsRestParams extends ClassifierTypeTrait {
@@ -165,9 +172,13 @@ final class OneVsRestModel private[ml] (
     val newDataset = dataset.withColumn(accColName, initUDF())
 
     // persist if underlying dataset is not persistent.
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) {
-      newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+    if (dataset.storageLevel == StorageLevel.NONE) {
+      if ($(handlePersistence)) {
+        newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+      } else {
+        logWarning("The input dataset is uncached, which may hurt performance if its " +
+          "upstreams are also uncached.")
+      }
     }
 
     // update the accumulator column with the result of prediction of models
@@ -191,8 +202,8 @@
       updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName)
     }
 
-    if (handlePersistence) {
-      newDataset.unpersist()
+    if (newDataset.storageLevel != StorageLevel.NONE) {
+      newDataset.unpersist(blocking = false)
     }
 
     // output the index of the classifier with highest confidence as prediction
@@ -359,9 +370,13 @@ final class OneVsRest @Since("1.4.0") (
     }
 
     // persist if underlying dataset is not persistent.
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) {
-      multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+    if (dataset.storageLevel == StorageLevel.NONE) {
+      if ($(handlePersistence)) {
+        multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+      } else {
+        logWarning("The input dataset is uncached, which may hurt performance if its " +
+          "upstreams are also uncached.")
+      }
     }
 
     val executionContext = getExecutionContext
@@ -392,8 +407,8 @@ final class OneVsRest @Since("1.4.0") (
       .map(ThreadUtils.awaitResult(_, Duration.Inf)).toArray[ClassificationModel[_, _]]
     instr.logNumFeatures(models.head.numFeatures)
 
-    if (handlePersistence) {
-      multiclassLabeled.unpersist()
+    if (multiclassLabeled.storageLevel != StorageLevel.NONE) {
+      multiclassLabeled.unpersist(blocking = false)
     }
 
     // extract label metadata from label column if present, or create a nominal attribute
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index f2af7fe082b4..4967e1945c1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
  * Common params for KMeans and KMeansModel
 */
 private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
-  with HasSeed with HasPredictionCol with HasTol {
+  with HasSeed with HasPredictionCol with HasTol with HasHandlePersistence {
 
   /**
    * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -300,20 +300,31 @@ class KMeans @Since("1.5.0") (
   @Since("1.5.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /**
+   * Sets whether to handle data persistence.
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
     transformSchema(dataset.schema, logging = true)
 
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
 
-    if (handlePersistence) {
-      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if (dataset.storageLevel == StorageLevel.NONE) {
+      if ($(handlePersistence)) {
+        instances.persist(StorageLevel.MEMORY_AND_DISK)
+      } else {
+        logWarning("The input dataset is uncached, which may hurt performance if its " +
+          "upstreams are also uncached.")
+      }
     }
 
-    val instr = Instrumentation.create(this, instances)
+    val instr = Instrumentation.create(this, dataset)
     instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
     val algo = new MLlibKMeans()
       .setK($(k))
@@ -329,8 +340,8 @@ class KMeans @Since("1.5.0") (
     model.setSummary(Some(summary))
     instr.logSuccess(model)
 
-    if (handlePersistence) {
-      instances.unpersist()
+    if (instances.getStorageLevel != StorageLevel.NONE) {
+      instances.unpersist(blocking = false)
     }
     model
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 1860fe836174..b86c6a8cd55f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -82,7 +82,9 @@
         "all instance weights as 1.0"),
       ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
       ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
-        isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
+        isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
+      ParamDesc[Boolean]("handlePersistence", "whether to handle data persistence. If true, " +
+        "we will cache unpersisted input data before fitting estimator on it", Some("true")))
 
     val code = genSharedParams(params)
     val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 6061d9ca0a08..67ed8b7c2ad3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -402,4 +402,21 @@ private[ml] trait HasAggregationDepth extends Params {
   /** @group expertGetParam */
   final def getAggregationDepth: Int = $(aggregationDepth)
 }
+
+/**
+ * Trait for shared param handlePersistence (default: true).
+ */
+private[ml] trait HasHandlePersistence extends Params {
+
+  /**
+   * Param for whether to handle data persistence. If true, we will cache unpersisted input data before fitting estimator on it.
+   * @group param
+   */
+  final val handlePersistence: BooleanParam = new BooleanParam(this, "handlePersistence", "whether to handle data persistence. If true, we will cache unpersisted input data before fitting estimator on it")
+
+  setDefault(handlePersistence, true)
+
+  /** @group getParam */
+  final def getHandlePersistence: Boolean = $(handlePersistence)
+}
 // scalastyle:on
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 4b46c3831d75..abc59c911a15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -46,7 +46,8 @@ import org.apache.spark.storage.StorageLevel
  */
 private[regression] trait AFTSurvivalRegressionParams extends Params
   with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
-  with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
+  with HasTol with HasFitIntercept with HasAggregationDepth with HasHandlePersistence
+  with Logging {
 
   /**
    * Param for censor column name.
@@ -197,6 +198,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)
 
+  /**
+   * Sets whether to handle data persistence.
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   /**
    * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
   * and put it in an RDD with strong types.
@@ -213,8 +221,15 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
     transformSchema(dataset.schema, logging = true)
     val instances = extractAFTPoints(dataset)
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+
+    if (dataset.storageLevel == StorageLevel.NONE) {
+      if ($(handlePersistence)) {
+        instances.persist(StorageLevel.MEMORY_AND_DISK)
+      } else {
+        logWarning("The input dataset is uncached, which may hurt performance if its " +
+          "upstreams are also uncached.")
+      }
+    }
 
     val featuresSummarizer = {
       val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
@@ -273,7 +288,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
     }
 
     bcFeaturesStd.destroy(blocking = false)
-    if (handlePersistence) instances.unpersist()
+    if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist(blocking = false)
 
     val rawCoefficients = parameters.slice(2, parameters.length)
     var i = 0
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 917a4d238d46..524c7bcbb8f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -373,6 +373,32 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   @Since("2.0.0")
   def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)
 
+  override def fit(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
+    transformSchema(dataset.schema, logging = true)
+
+    val cols = collection.mutable.ArrayBuffer[Column]()
+    cols.append(col($(featuresCol)))
+
+    // Cast LabelCol to DoubleType and keep the metadata.
+    val labelMeta = dataset.schema($(labelCol)).metadata
+    cols.append(col($(labelCol)).cast(DoubleType).as($(labelCol), labelMeta))
+
+    // Cast WeightCol to DoubleType and keep the metadata.
+    if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+      val weightMeta = dataset.schema($(weightCol)).metadata
+      cols.append(col($(weightCol)).cast(DoubleType).as($(weightCol), weightMeta))
+    }
+
+    // Cast OffsetCol to DoubleType and keep the metadata.
+    if (isDefined(offsetCol) && $(offsetCol).nonEmpty) {
+      val offsetMeta = dataset.schema($(offsetCol)).metadata
+      cols.append(col($(offsetCol)).cast(DoubleType).as($(offsetCol), offsetMeta))
+    }
+
+    val selected = dataset.select(cols: _*)
+    train(selected)
+  }
+
   override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
     val familyAndLink = FamilyAndLink(this)
 
@@ -393,7 +419,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
         "set to false. To fit a model with 0 features, fitIntercept must be set to true."
       )
     val w = if (!hasWeightCol) lit(1.0) else col($(weightCol))
-    val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
+    val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol))
 
     val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) {
       // TODO: Make standardizeFeatures and standardizeLabel configurable.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 8faab52ea474..c4e0948b48f6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -39,7 +39,8 @@ import org.apache.spark.storage.StorageLevel
  * Params for isotonic regression.
 */
 private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol
-  with HasLabelCol with HasPredictionCol with HasWeightCol with Logging {
+  with HasLabelCol with HasPredictionCol with HasWeightCol with HasHandlePersistence
+  with Logging {
 
   /**
    * Param for whether the output sequence should be isotonic/increasing (true) or
@@ -157,6 +158,13 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
   @Since("1.5.0")
   def setFeatureIndex(value: Int): this.type = set(featureIndex, value)
 
+  /**
+   * Sets whether to handle data persistence.
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   @Since("1.5.0")
   override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
 
@@ -165,8 +173,14 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
     transformSchema(dataset.schema, logging = true)
     // Extract columns from data. If dataset is persisted, do not persist oldDataset.
     val instances = extractWeightedLabeledPoints(dataset)
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if (dataset.storageLevel == StorageLevel.NONE) {
+      if ($(handlePersistence)) {
+        instances.persist(StorageLevel.MEMORY_AND_DISK)
+      } else {
+        logWarning("The input dataset is uncached, which may hurt performance if its " +
+          "upstreams are also uncached.")
+      }
+    }
 
     val instr = Instrumentation.create(this, dataset)
     instr.logParams(labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic)
@@ -175,7 +189,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
     val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic))
     val oldModel = isotonicRegression.run(instances)
 
-    if (handlePersistence) instances.unpersist()
+    if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist(blocking = false)
 
     val model = copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this))
     instr.logSuccess(model)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index b2a968118d1a..739081cc66b4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -53,7 +53,7 @@ import org.apache.spark.storage.StorageLevel
 private[regression] trait LinearRegressionParams extends PredictorParams
     with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
     with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
-    with HasAggregationDepth {
+    with HasAggregationDepth with HasHandlePersistence {
 
   import LinearRegression._
 
@@ -208,6 +208,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)
 
+  /**
+   * Sets whether to handle data persistence.
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
@@ -232,6 +239,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
         elasticNetParam = $(elasticNetParam), $(standardization), true,
         solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol))
+
+      if (dataset.storageLevel != StorageLevel.NONE) dataset.unpersist(blocking = false)
+
       val model = optimizer.fit(instances)
       // When it is trained by WeightedLeastSquares, training summary does not
       // attach returned model.
@@ -251,9 +261,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       return lrModel
     }
 
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
-
     val (featuresSummarizer, ySummarizer) = {
       val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
         instance: Instance) =>
@@ -285,7 +292,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
           s"zeros and the intercept will be the mean of the label; as a result, " +
           s"training is not needed.")
       }
-      if (handlePersistence) instances.unpersist()
       val coefficients = Vectors.sparse(numFeatures, Seq.empty)
       val intercept = yMean
 
@@ -422,8 +428,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       0.0
     }
 
-    if (handlePersistence) instances.unpersist()
-
     val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
     // Handle possible missing or invalid prediction columns
     val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 4b650000736e..b079f25db666 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -444,13 +444,13 @@ class LogisticRegressionWithLBFGS
     lr.setFitIntercept(addIntercept)
     lr.setMaxIter(optimizer.getNumIterations())
     lr.setTol(optimizer.getConvergenceTol())
+    // Determine if we should cache the DF
+    lr.setHandlePersistence(input.getStorageLevel == StorageLevel.NONE)
     // Convert our input into a DataFrame
     val spark = SparkSession.builder().sparkContext(input.context).getOrCreate()
     val df = spark.createDataFrame(input.map(_.asML))
-    // Determine if we should cache the DF
-    val handlePersistence = input.getStorageLevel == StorageLevel.NONE
     // Train our model
-    val mlLogisticRegressionModel = lr.train(df, handlePersistence)
+    val mlLogisticRegressionModel = lr.fit(df)
     // convert the model
     val weights = Vectors.dense(mlLogisticRegressionModel.coefficients.toArray)
     createModel(weights, mlLogisticRegressionModel.intercept)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index dd299e074535..a2706ae442cb 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -68,7 +68,12 @@ object MimaExcludes {
 
     // [SPARK-14280] Support Scala 2.12
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transformWith"),
-    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform")
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform"),
+
+    // [SPARK-18608] Add Param HasHandlePersistence
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasHandlePersistence.handlePersistence"),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasHandlePersistence.getHandlePersistence"),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasHandlePersistence.org$apache$spark$ml$param$shared$HasHandlePersistence$_setter_$handlePersistence_=")
   )
 
   // Exclude rules for 2.2.x