update algs
zhengruifeng committed Sep 14, 2017
commit 3f11c67630dfc5402e49d7bf43d1ce9a31b400da
13 changes: 7 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -17,6 +17,8 @@

package org.apache.spark.ml

+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -126,10 +128,10 @@ abstract class Predictor[
* Developers can override this for specific purposes.
*
* @param dataset Original training dataset
- * @return Intermediate dataframe
+ * @return Intermediate training dataframe
*/
protected def preprocess(dataset: Dataset[_]): DataFrame = {
- val cols = collection.mutable.ArrayBuffer[Column]()
+ val cols = ArrayBuffer[Column]()
cols.append(col($(featuresCol)))

// Cast LabelCol to DoubleType and keep the metadata.
@@ -146,7 +148,7 @@

val selected = dataset.select(cols: _*)

- val cached = this match {
+ this match {
case p: HasHandlePersistence =>
if (dataset.storageLevel == StorageLevel.NONE) {
if ($(p.handlePersistence)) {
@@ -156,11 +158,10 @@
"upstreams are also uncached.")
}
}
- selected
- case _ => selected
+ case _ =>
}

- cached
+ selected
}

/**
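Note: the `HasHandlePersistence` shared param that this commit mixes in everywhere is not itself visible in the changed files. A minimal sketch of what such a trait would presumably look like — the param name matches the diff, but the doc string and the default of `true` are assumptions, not confirmed by this commit:

import org.apache.spark.ml.param.{BooleanParam, Params}

private[ml] trait HasHandlePersistence extends Params {

  /**
   * Param for whether to cache an uncached input dataset (MEMORY_AND_DISK)
   * before training. (Hypothetical sketch; not part of this diff.)
   * @group param
   */
  final val handlePersistence: BooleanParam = new BooleanParam(this,
    "handlePersistence", "whether to cache the input dataset if it is uncached")

  setDefault(handlePersistence -> true)

  /** @group getParam */
  final def getHandlePersistence: Boolean = $(handlePersistence)
}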
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -51,7 +51,8 @@ import org.apache.spark.util.VersionUtils
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
- with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth {
+ with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth
+ with HasHandlePersistence {

import org.apache.spark.ml.classification.LogisticRegression.supportedFamilyNames

@@ -431,6 +432,13 @@ class LogisticRegression @Since("1.2.0") (
@Since("2.2.0")
def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)

+ /**
+  * Sets whether to handle data persistence.
+  * @group setParam
+  */
+ @Since("2.3.0")
+ def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
private def assertBoundConstrainedOptimizationParamsValid(
numCoefficientSets: Int,
numFeatures: Int): Unit = {
@@ -484,23 +492,14 @@
}

override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
- val handlePersistence = dataset.storageLevel == StorageLevel.NONE
- train(dataset, handlePersistence)
- }
-
- protected[spark] def train(
-     dataset: Dataset[_],
-     handlePersistence: Boolean): LogisticRegressionModel = {
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}

- if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
-
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(regParam, elasticNetParam, standardization, threshold,
maxIter, tol, fitIntercept)

@@ -878,8 +877,6 @@ class LogisticRegression @Since("1.2.0") (
}
}

- if (handlePersistence) instances.unpersist()
-
val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
numClasses, isMultinomial))

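With the two-argument `train(dataset, handlePersistence)` overload removed, callers opt out of the implicit cache through the new param instead. A usage sketch, assuming a caller that has already persisted its training data (the path, column layout, and `spark` session setup here are hypothetical):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

val spark = SparkSession.builder().appName("lr-demo").master("local[2]").getOrCreate()
// Hypothetical pre-cached training data with "label"/"features" columns.
val training = spark.read.parquet("/path/to/training").persist(StorageLevel.MEMORY_AND_DISK)

val lr = new LogisticRegression()
  .setMaxIter(100)
  .setRegParam(0.01)
  .setHandlePersistence(false) // caller manages caching, skip the extra copy
val model = lr.fit(training)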
mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -34,7 +34,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
- import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
+ import org.apache.spark.ml.param.shared.{HasHandlePersistence, HasParallelism, HasWeightCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -56,7 +56,7 @@ private[ml] trait ClassifierTypeTrait {
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams
- with ClassifierTypeTrait with HasWeightCol {
+ with ClassifierTypeTrait with HasWeightCol with HasHandlePersistence {

/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -68,6 +68,13 @@ private[ml] trait OneVsRestParams extends PredictorParams

/** @group getParam */
def getClassifier: ClassifierType = $(classifier)

+ /**
+  * Sets whether to handle data persistence.
+  * @group setParam
+  */
+ @Since("2.3.0")
+ def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
}

private[ml] object OneVsRestParams extends ClassifierTypeTrait {
@@ -165,9 +172,13 @@ final class OneVsRestModel private[ml] (
val newDataset = dataset.withColumn(accColName, initUDF())

// persist if underlying dataset is not persistent.
- val handlePersistence = dataset.storageLevel == StorageLevel.NONE
- if (handlePersistence) {
-   newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+ if (dataset.storageLevel == StorageLevel.NONE) {
+   if ($(handlePersistence)) {
+     newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+   } else {
+     logWarning("The input dataset is uncached, which may hurt performance if its " +
+       "upstreams are also uncached.")
+   }
}

// update the accumulator column with the result of prediction of models
@@ -191,8 +202,8 @@
updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName)
}

- if (handlePersistence) {
-   newDataset.unpersist()
+ if (newDataset.storageLevel != StorageLevel.NONE) {
+   newDataset.unpersist(blocking = false)
}

// output the index of the classifier with highest confidence as prediction
@@ -359,9 +370,13 @@ final class OneVsRest @Since("1.4.0") (
}

// persist if underlying dataset is not persistent.
- val handlePersistence = dataset.storageLevel == StorageLevel.NONE
- if (handlePersistence) {
-   multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+ if (dataset.storageLevel == StorageLevel.NONE) {
+   if ($(handlePersistence)) {
+     multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+   } else {
+     logWarning("The input dataset is uncached, which may hurt performance if its " +
+       "upstreams are also uncached.")
+   }
}

val executionContext = getExecutionContext
@@ -392,8 +407,8 @@
.map(ThreadUtils.awaitResult(_, Duration.Inf)).toArray[ClassificationModel[_, _]]
instr.logNumFeatures(models.head.numFeatures)

- if (handlePersistence) {
-   multiclassLabeled.unpersist()
+ if (multiclassLabeled.storageLevel != StorageLevel.NONE) {
+   multiclassLabeled.unpersist(blocking = false)
}

// extract label metadata from label column if present, or create a nominal attribute
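The persist-or-warn block now appears verbatim in `OneVsRestModel.transform`, `OneVsRest.fit`, and the other estimators below. A sketch of a helper that would factor it out — hypothetical, not part of this commit; the trait and method names are invented for illustration:

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Dataset
import org.apache.spark.storage.StorageLevel

private[ml] trait PersistenceHelper extends Logging {
  /** Persist `data` if the source dataset is uncached and handling is enabled. */
  protected def persistIfNeeded(
      source: Dataset[_],
      data: Dataset[_],
      handlePersistence: Boolean): Unit = {
    if (source.storageLevel == StorageLevel.NONE) {
      if (handlePersistence) {
        data.persist(StorageLevel.MEMORY_AND_DISK)
      } else {
        logWarning("The input dataset is uncached, which may hurt performance if its " +
          "upstreams are also uncached.")
      }
    }
  }
}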
25 changes: 18 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
* Common params for KMeans and KMeansModel
*/
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
- with HasSeed with HasPredictionCol with HasTol {
+ with HasSeed with HasPredictionCol with HasTol with HasHandlePersistence {

/**
* The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -300,20 +300,31 @@ class KMeans @Since("1.5.0") (
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)

+ /**
+  * Sets whether to handle data persistence.
+  * @group setParam
+  */
+ @Since("2.3.0")
+ def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = {
transformSchema(dataset.schema, logging = true)

- val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}

- if (handlePersistence) {
-   instances.persist(StorageLevel.MEMORY_AND_DISK)
+ if (dataset.storageLevel == StorageLevel.NONE) {
+   if ($(handlePersistence)) {
+     instances.persist(StorageLevel.MEMORY_AND_DISK)
+   } else {
+     logWarning("The input dataset is uncached, which may hurt performance if its " +
+       "upstreams are also uncached.")
+   }
}

- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
val algo = new MLlibKMeans()
.setK($(k))
@@ -329,8 +340,8 @@

model.setSummary(Some(summary))
instr.logSuccess(model)
- if (handlePersistence) {
-   instances.unpersist()
+ if (instances.getStorageLevel != StorageLevel.NONE) {
+   instances.unpersist(blocking = false)
}
model
}
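Note that two cache checks are in play in KMeans: the source `Dataset` exposes `storageLevel` while the derived `RDD` exposes `getStorageLevel`, so the guard before training and the guard before unpersisting test different objects. A usage sketch — `features` is an assumed, uncached DataFrame with a "features" vector column:

import org.apache.spark.ml.clustering.KMeans

val kmeans = new KMeans()
  .setK(4)
  .setSeed(1L)
  .setHandlePersistence(true) // cache the uncached input for the duration of training
val model = kmeans.fit(features)
println(model.clusterCenters.mkString("\n"))

With `handlePersistence` set to false on an uncached input, training now logs the warning instead of silently caching.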
mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -46,7 +46,8 @@ import org.apache.spark.storage.StorageLevel
*/
private[regression] trait AFTSurvivalRegressionParams extends Params
with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
- with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
+ with HasTol with HasFitIntercept with HasAggregationDepth with HasHandlePersistence
+ with Logging {

/**
* Param for censor column name.
@@ -197,6 +198,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
setDefault(aggregationDepth -> 2)

+ /**
+  * Sets whether to handle data persistence.
+  * @group setParam
+  */
+ @Since("2.3.0")
+ def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
/**
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
* and put it in an RDD with strong types.
@@ -213,8 +221,15 @@
override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
transformSchema(dataset.schema, logging = true)
val instances = extractAFTPoints(dataset)
- val handlePersistence = dataset.storageLevel == StorageLevel.NONE
- if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+
+ if (dataset.storageLevel == StorageLevel.NONE) {
+   if ($(handlePersistence)) {
+     instances.persist(StorageLevel.MEMORY_AND_DISK)
+   } else {
+     logWarning("The input dataset is uncached, which may hurt performance if its " +
+       "upstreams are also uncached.")
+   }
+ }

val featuresSummarizer = {
val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
@@ -273,7 +288,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
}

bcFeaturesStd.destroy(blocking = false)
- if (handlePersistence) instances.unpersist()
+ if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist(blocking = false)

val rawCoefficients = parameters.slice(2, parameters.length)
var i = 0
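The unpersist guard also changed shape here: instead of remembering whether the estimator itself called persist, the code checks the RDD's current storage level and releases the cache asynchronously. A small self-contained illustration of those two RDD calls (demo code, not from the diff; the session setup is an assumption):

import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

val spark = SparkSession.builder().appName("unpersist-demo").master("local[2]").getOrCreate()
val rdd = spark.sparkContext.parallelize(1 to 1000)

rdd.persist(StorageLevel.MEMORY_AND_DISK)
assert(rdd.getStorageLevel != StorageLevel.NONE)

// blocking = false releases the cached blocks asynchronously, so the fit
// method does not wait on executors before returning the trained model.
rdd.unpersist(blocking = false)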
mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -373,6 +373,29 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
@Since("2.0.0")
def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)

+ override protected def preprocess(dataset: Dataset[_]): DataFrame = {
+   val cols = collection.mutable.ArrayBuffer[Column]()
+   cols.append(col($(featuresCol)))
+
+   // Cast LabelCol to DoubleType and keep the metadata.
+   val labelMeta = dataset.schema($(labelCol)).metadata
+   cols.append(col($(labelCol)).cast(DoubleType).as($(labelCol), labelMeta))
+
+   // Cast WeightCol to DoubleType and keep the metadata.
+   if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+     val weightMeta = dataset.schema($(weightCol)).metadata
+     cols.append(col($(weightCol)).cast(DoubleType).as($(weightCol), weightMeta))
+   }
+
+   // Cast OffsetCol to DoubleType and keep the metadata.
+   if (isDefined(offsetCol) && $(offsetCol).nonEmpty) {
+     val offsetMeta = dataset.schema($(offsetCol)).metadata
+     cols.append(col($(offsetCol)).cast(DoubleType).as($(offsetCol), offsetMeta))
+   }
+
+   dataset.select(cols: _*)
+ }
+
override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
val familyAndLink = FamilyAndLink(this)

@@ -393,7 +416,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
"set to false. To fit a model with 0 features, fitIntercept must be set to true." )

val w = if (!hasWeightCol) lit(1.0) else col($(weightCol))
- val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
+ val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol))

val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) {
// TODO: Make standardizeFeatures and standardizeLabel configurable.
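Because the new `preprocess` override already casts `offsetCol` to `DoubleType` (keeping its metadata), the later `.cast(DoubleType)` at the training site becomes redundant and is dropped. A usage sketch with an integer-typed offset column — the column names and the `claims` DataFrame are hypothetical:

import org.apache.spark.ml.regression.GeneralizedLinearRegression

val glr = new GeneralizedLinearRegression()
  .setFamily("poisson")
  .setLink("log")
  .setOffsetCol("logExposure") // cast to DoubleType by preprocess(), metadata preserved
val model = glr.fit(claims)   // `claims` is an assumed DataFrame of training rows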
mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -39,7 +39,8 @@ import org.apache.spark.storage.StorageLevel
* Params for isotonic regression.
*/
private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol
- with HasLabelCol with HasPredictionCol with HasWeightCol with Logging {
+ with HasLabelCol with HasPredictionCol with HasWeightCol with HasHandlePersistence
+ with Logging {

/**
* Param for whether the output sequence should be isotonic/increasing (true) or
@@ -157,6 +158,13 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
@Since("1.5.0")
def setFeatureIndex(value: Int): this.type = set(featureIndex, value)

+ /**
+  * Sets whether to handle data persistence.
+  * @group setParam
+  */
+ @Since("2.3.0")
+ def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
@Since("1.5.0")
override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)

@@ -165,8 +173,14 @@
transformSchema(dataset.schema, logging = true)
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractWeightedLabeledPoints(dataset)
- val handlePersistence = dataset.storageLevel == StorageLevel.NONE
- if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+ if (dataset.storageLevel == StorageLevel.NONE) {
+   if ($(handlePersistence)) {
+     instances.persist(StorageLevel.MEMORY_AND_DISK)
+   } else {
+     logWarning("The input dataset is uncached, which may hurt performance if its " +
+       "upstreams are also uncached.")
+   }
+ }

val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic)
@@ -175,7 +189,7 @@
val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic))
val oldModel = isotonicRegression.run(instances)

- if (handlePersistence) instances.unpersist()
+ if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist(blocking = false)

val model = copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this))
instr.logSuccess(model)
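Finally, the warning path itself: with `handlePersistence` disabled, fitting on an uncached input no longer caches behind the caller's back, it just warns. A sketch of that path — the estimator choice is arbitrary and `uncachedPoints` is an assumed, uncached DataFrame:

import org.apache.spark.ml.regression.IsotonicRegression

val ir = new IsotonicRegression()
  .setHandlePersistence(false)
// Fitting an uncached DataFrame now logs
//   "The input dataset is uncached, which may hurt performance if its
//    upstreams are also uncached."
// instead of implicitly persisting it to MEMORY_AND_DISK.
val model = ir.fit(uncachedPoints)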