Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
del pre/post
  • Loading branch information
zhengruifeng committed Sep 26, 2017
commit 18f9903707d029d40c5eb03dc8e856a6607ac723
57 changes: 15 additions & 42 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,35 +102,6 @@ abstract class Predictor[
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)

val dataframe = preprocess(dataset)

val model = copyValues(train(dataframe).setParent(this))

postprocess(dataframe)

model
}

override def copy(extra: ParamMap): Learner

/**
* Train a model using the given dataset and parameters.
* Developers can implement this instead of `fit()` to avoid dealing with schema validation
* and copying parameters into the model.
*
* @param dataset Training dataset
* @return Fitted model
*/
protected def train(dataset: Dataset[_]): M

/**
* Pre-process the input dataset to an intermediate dataframe.
* Developers can override this for specific purpose.
*
* @param dataset Original training dataset
* @return Intermediate training dataframe
*/
protected def preprocess(dataset: Dataset[_]): DataFrame = {
val cols = ArrayBuffer[Column]()
cols.append(col($(featuresCol)))

Expand Down Expand Up @@ -161,24 +132,26 @@ abstract class Predictor[
case _ =>
}

selected
val model = copyValues(train(selected).setParent(this))

if (selected.storageLevel != StorageLevel.NONE) {
selected.unpersist(blocking = false)
}

model
}

override def copy(extra: ParamMap): Learner

/**
* Post-process the intermediate dataframe.
* Developers can override this for specific purpose.
* Train a model using the given dataset and parameters.
* Developers can implement this instead of `fit()` to avoid dealing with schema validation
* and copying parameters into the model.
*
* @param dataset Intermediate training dataframe
* @param dataset Training dataset
* @return Fitted model
*/
/**
 * Post-process the intermediate training dataframe.
 * Releases the cached dataframe, but only when this estimator mixes in
 * `HasHandlePersistence` (i.e. it opted into managing persistence itself)
 * and the dataframe is actually persisted. Non-blocking unpersist so model
 * fitting is not delayed by cache cleanup.
 *
 * @param dataset Intermediate training dataframe
 */
protected def postprocess(dataset: DataFrame): Unit = this match {
  // Guard combines the capability check and the storage check; unpersisting
  // an un-cached dataframe would be a no-op anyway, so we skip it entirely.
  case _: HasHandlePersistence if dataset.storageLevel != StorageLevel.NONE =>
    dataset.unpersist(blocking = false)
  case _ => // estimator does not manage persistence, or nothing was cached
}
protected def train(dataset: Dataset[_]): M

/**
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
@Since("2.0.0")
def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)

override protected def preprocess(dataset: Dataset[_]): DataFrame = {
override def fit(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
val cols = collection.mutable.ArrayBuffer[Column]()
cols.append(col($(featuresCol)))

Expand All @@ -393,7 +393,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
cols.append(col($(offsetCol)).cast(DoubleType).as($(offsetCol), offsetMeta))
}

dataset.select(cols: _*)
val selected = dataset.select(cols: _*)
train(selected)
}

override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
Expand Down