45 changes: 34 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -17,16 +17,19 @@

package org.apache.spark.ml

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.storage.StorageLevel

/**
* (private[ml]) Trait for parameters for prediction (regression and classification).
@@ -99,23 +102,43 @@ abstract class Predictor[
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)

val cols = ArrayBuffer[Column]()
cols.append(col($(featuresCol)))

// Cast LabelCol to DoubleType and keep the metadata.
val labelMeta = dataset.schema($(labelCol)).metadata
val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
cols.append(col($(labelCol)).cast(DoubleType).as($(labelCol), labelMeta))

// Cast WeightCol to DoubleType and keep the metadata.
val casted = this match {
case p: HasWeightCol =>
if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
val weightMeta = dataset.schema($(p.weightCol)).metadata
labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta)
} else {
labelCasted
this match {
case p: HasWeightCol if isDefined(p.weightCol) && $(p.weightCol).nonEmpty =>
val weightMeta = dataset.schema($(p.weightCol)).metadata
cols.append(col($(p.weightCol)).cast(DoubleType).as($(p.weightCol), weightMeta))
case _ =>
}

val selected = dataset.select(cols: _*)

this match {
case p: HasHandlePersistence =>
if (dataset.storageLevel == StorageLevel.NONE) {
if ($(p.handlePersistence)) {
selected.persist(StorageLevel.MEMORY_AND_DISK)
} else {
logWarning("The input dataset is uncached, which may hurt performance if its " +
"upstreams are also uncached.")
}
}
case _ => labelCasted
case _ =>
}

val model = copyValues(train(selected).setParent(this))

if (selected.storageLevel != StorageLevel.NONE) {
selected.unpersist(blocking = false)
}

copyValues(train(casted).setParent(this))
model
}

override def copy(extra: ParamMap): Learner
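For reference, a minimal caller-side sketch of how the reworked Predictor.fit() flow behaves with the shared flag. The data and variable names are illustrative, and LogisticRegression is used only because it opts into the param later in this diff:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

val spark = SparkSession.builder().appName("handlePersistence-demo").getOrCreate()
val training = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

val lr = new LogisticRegression().setMaxIter(10)

// Default (handlePersistence = true): `training` is uncached, so fit() persists the
// selected (features, label) projection at MEMORY_AND_DISK, trains, then unpersists it.
val cachedRun = lr.fit(training)

// Opting out of estimator-side caching, e.g. when the upstream lineage is cheap to
// recompute: fit() only logs the "input dataset is uncached" warning.
val uncachedRun = lr.setHandlePersistence(false).fit(training)

// An input that is already persisted is left untouched regardless of the param.
training.persist(StorageLevel.MEMORY_AND_DISK)
val prePersistedRun = lr.fit(training)

The flag only matters when the input is uncached; a DataFrame the caller has already persisted is never re-cached or unpersisted by the estimator.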
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -51,7 +51,8 @@ import org.apache.spark.util.VersionUtils
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth {
with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth
with HasHandlePersistence {

import org.apache.spark.ml.classification.LogisticRegression.supportedFamilyNames

@@ -431,6 +432,13 @@ class LogisticRegression @Since("1.2.0") (
@Since("2.2.0")
def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)

/**
* Sets whether to handle data persistence.
* @group setParam
*/
@Since("2.3.0")
def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)

private def assertBoundConstrainedOptimizationParamsValid(
numCoefficientSets: Int,
numFeatures: Int): Unit = {
@@ -484,23 +492,14 @@
}

override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
train(dataset, handlePersistence)
}

protected[spark] def train(
dataset: Dataset[_],
handlePersistence: Boolean): LogisticRegressionModel = {
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}

if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

val instr = Instrumentation.create(this, instances)
val instr = Instrumentation.create(this, dataset)
instr.logParams(regParam, elasticNetParam, standardization, threshold,
maxIter, tol, fitIntercept)

@@ -878,8 +877,6 @@
}
}

if (handlePersistence) instances.unpersist()

val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
numClasses, isMultinomial))

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -34,7 +34,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
import org.apache.spark.ml.param.shared.{HasHandlePersistence, HasParallelism, HasWeightCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -56,7 +56,7 @@ private[ml] trait ClassifierTypeTrait {
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams
with ClassifierTypeTrait with HasWeightCol {
with ClassifierTypeTrait with HasWeightCol with HasHandlePersistence {

/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -68,6 +68,13 @@ private[ml] trait OneVsRestParams extends PredictorParams

/** @group getParam */
def getClassifier: ClassifierType = $(classifier)

/**
* Sets whether to handle data persistence.
* @group setParam
*/
@Since("2.3.0")
def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
}

private[ml] object OneVsRestParams extends ClassifierTypeTrait {
@@ -165,9 +172,13 @@ final class OneVsRestModel private[ml] (
val newDataset = dataset.withColumn(accColName, initUDF())

// persist if underlying dataset is not persistent.
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) {
newDataset.persist(StorageLevel.MEMORY_AND_DISK)
if (dataset.storageLevel == StorageLevel.NONE) {
if ($(handlePersistence)) {
newDataset.persist(StorageLevel.MEMORY_AND_DISK)
} else {
logWarning("The input dataset is uncached, which may hurt performance if its " +
"upstreams are also uncached.")
}
}

// update the accumulator column with the result of prediction of models
@@ -191,8 +202,8 @@ final class OneVsRestModel private[ml] (
updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName)
}

if (handlePersistence) {
newDataset.unpersist()
if (newDataset.storageLevel != StorageLevel.NONE) {
newDataset.unpersist(blocking = false)
}

// output the index of the classifier with highest confidence as prediction
@@ -359,9 +370,13 @@ final class OneVsRest @Since("1.4.0") (
}

// persist if underlying dataset is not persistent.
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) {
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
if (dataset.storageLevel == StorageLevel.NONE) {
if ($(handlePersistence)) {
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
} else {
logWarning("The input dataset is uncached, which may hurt performance if its " +
"upstreams are also uncached.")
}
}

val executionContext = getExecutionContext
Expand Down Expand Up @@ -392,8 +407,8 @@ final class OneVsRest @Since("1.4.0") (
.map(ThreadUtils.awaitResult(_, Duration.Inf)).toArray[ClassificationModel[_, _]]
instr.logNumFeatures(models.head.numFeatures)

if (handlePersistence) {
multiclassLabeled.unpersist()
if (multiclassLabeled.storageLevel != StorageLevel.NONE) {
multiclassLabeled.unpersist(blocking = false)
}

// extract label metadata from label column if present, or create a nominal attribute
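A hedged usage sketch for OneVsRest, assuming multiclass DataFrames named train and test with the default "label" and "features" columns. Since the setter lives in OneVsRestParams and OneVsRest.fit copies param values onto the model, the flag should also govern the temporary accumulator dataset cached during OneVsRestModel.transform():

import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

val ovr = new OneVsRest()
  .setClassifier(new LogisticRegression().setMaxIter(10))
  .setHandlePersistence(false)

// fit() logs the uncached-input warning instead of persisting multiclassLabeled; the
// flag should propagate to the model via copyValues, so transform() would likewise
// skip caching newDataset.
val ovrModel = ovr.fit(train)
val predictions = ovrModel.transform(test)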
25 changes: 18 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
* Common params for KMeans and KMeansModel
*/
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
with HasSeed with HasPredictionCol with HasTol {
with HasSeed with HasPredictionCol with HasTol with HasHandlePersistence {

/**
* The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -300,20 +300,31 @@ class KMeans @Since("1.5.0") (
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)

/**
* Sets whether to handle data persistence.
* @group setParam
*/
@Since("2.3.0")
def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = {
transformSchema(dataset.schema, logging = true)

val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}

if (handlePersistence) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
if (dataset.storageLevel == StorageLevel.NONE) {
if ($(handlePersistence)) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
} else {
logWarning("The input dataset is uncached, which may hurt performance if its " +
"upstreams are also uncached.")
}
}

val instr = Instrumentation.create(this, instances)
val instr = Instrumentation.create(this, dataset)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
val algo = new MLlibKMeans()
.setK($(k))
@@ -329,8 +340,8 @@

model.setSummary(Some(summary))
instr.logSuccess(model)
if (handlePersistence) {
instances.unpersist()
if (instances.getStorageLevel != StorageLevel.NONE) {
instances.unpersist(blocking = false)
}
model
}
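The same pattern applies to KMeans, with the difference that the caching happens on the intermediate RDD of vectors extracted from the DataFrame rather than on the DataFrame itself. A short sketch, with `dataset` standing in for an uncached DataFrame that has a "features" vector column:

import org.apache.spark.ml.clustering.KMeans

val kmeans = new KMeans().setK(3).setSeed(1L)

// Default: the RDD of vectors extracted from the DataFrame is persisted at
// MEMORY_AND_DISK while the MLlib k-means iterations run, then unpersisted.
val model = kmeans.fit(dataset)

// With caching disabled, KMeans only logs the warning and leaves persistence to the caller.
val model2 = kmeans.setHandlePersistence(false).fit(dataset)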
mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -82,7 +82,9 @@ private[shared] object SharedParamsCodeGen {
"all instance weights as 1.0"),
ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
ParamDesc[Boolean]("handlePersistence", "whether to handle data persistence. If true, " +
"we will cache unpersisted input data before fitting estimator on it", Some("true")))

val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -402,4 +402,21 @@ private[ml] trait HasAggregationDepth extends Params {
/** @group expertGetParam */
final def getAggregationDepth: Int = $(aggregationDepth)
}

/**
* Trait for shared param handlePersistence (default: true).
*/
private[ml] trait HasHandlePersistence extends Params {

/**
* Param for whether to handle data persistence. If true, we will cache unpersisted input data before fitting estimator on it.
* @group param
*/
final val handlePersistence: BooleanParam = new BooleanParam(this, "handlePersistence", "whether to handle data persistence. If true, we will cache unpersisted input data before fitting estimator on it")

setDefault(handlePersistence, true)

/** @group getParam */
final def getHandlePersistence: Boolean = $(handlePersistence)
}
// scalastyle:on
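To make the intended convention concrete, here is a hypothetical helper that is not part of this PR; the object name, method names, and its placement under org.apache.spark.ml.util are assumptions, chosen so that private[ml] and the internal Logging trait resolve inside the mllib source tree:

package org.apache.spark.ml.util

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Dataset
import org.apache.spark.storage.StorageLevel

// Hypothetical helper capturing the persist-or-warn convention used by estimators
// that mix in HasHandlePersistence.
private[ml] object PersistenceUtils extends Logging {

  // Persist `data` when it is uncached and the estimator's handlePersistence flag is set;
  // otherwise warn that recomputing the lineage may be expensive.
  def persistIfRequested(data: Dataset[_], handlePersistence: Boolean): Unit = {
    if (data.storageLevel == StorageLevel.NONE) {
      if (handlePersistence) {
        data.persist(StorageLevel.MEMORY_AND_DISK)
      } else {
        logWarning("The input dataset is uncached, which may hurt performance if its " +
          "upstreams are also uncached.")
      }
    }
  }

  // Unpersist `data` only if it is currently cached, mirroring the unpersist checks above.
  def unpersistIfCached(data: Dataset[_]): Unit = {
    if (data.storageLevel != StorageLevel.NONE) {
      data.unpersist(blocking = false)
    }
  }
}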
mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -46,7 +46,8 @@ import org.apache.spark.storage.StorageLevel
*/
private[regression] trait AFTSurvivalRegressionParams extends Params
with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
with HasTol with HasFitIntercept with HasAggregationDepth with HasHandlePersistence
with Logging {

/**
* Param for censor column name.
@@ -197,6 +198,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
setDefault(aggregationDepth -> 2)

/**
* Sets whether to handle data persistence.
* @group setParam
*/
@Since("2.3.0")
def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)

/**
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
* and put it in an RDD with strong types.
@@ -213,8 +221,15 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
transformSchema(dataset.schema, logging = true)
val instances = extractAFTPoints(dataset)
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

if (dataset.storageLevel == StorageLevel.NONE) {
if ($(handlePersistence)) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
} else {
logWarning("The input dataset is uncached, which may hurt performance if its " +
"upstreams are also uncached.")
}
}

val featuresSummarizer = {
val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
@@ -273,7 +288,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
}

bcFeaturesStd.destroy(blocking = false)
if (handlePersistence) instances.unpersist()
if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist(blocking = false)

val rawCoefficients = parameters.slice(2, parameters.length)
var i = 0