mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -21,20 +21,20 @@ import scala.collection.mutable

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkException, Logging}
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.{Model, Estimator}
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.linalg.BLAS
import org.apache.spark.ml.util._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{Logging, SparkException}

/**
* Params for accelerated failure time (AFT) regression.
@@ -120,7 +120,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
@Experimental
@Since("1.6.0")
class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String)
extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging {
extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams
with DefaultParamsWritable with Logging {

@Since("1.6.0")
def this() = this(Identifiable.randomUID("aftSurvReg"))
@@ -243,6 +244,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra)
}

@Since("1.6.0")
object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression] {

@Since("1.6.0")
override def load(path: String): AFTSurvivalRegression = super.load(path)
}
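
// For reference, a minimal usage sketch of the estimator persistence this adds:
// DefaultParamsWritable supplies write/save for the Params-only estimator and the
// companion's DefaultParamsReadable supplies load. The path below is hypothetical,
// and an active SparkContext/SQLContext is assumed to be in scope.
//
//   val aft = new AFTSurvivalRegression()
//     .setMaxIter(50)
//     .setTol(1E-4)
//   val estimatorPath = "/tmp/aft-estimator"          // hypothetical location
//   aft.write.overwrite().save(estimatorPath)         // persists uid and Param values only
//   val restored = AFTSurvivalRegression.load(estimatorPath)
//   assert(restored.getMaxIter == 50)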

/**
* :: Experimental ::
* Model produced by [[AFTSurvivalRegression]].
@@ -254,7 +262,7 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("1.6.0") val coefficients: Vector,
@Since("1.6.0") val intercept: Double,
@Since("1.6.0") val scale: Double)
extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams {
extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable {

/** @group setParam */
@Since("1.6.0")
@@ -312,6 +320,58 @@ class AFTSurvivalRegressionModel private[ml] (
copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra)
.setParent(parent)
}

@Since("1.6.0")
override def write: MLWriter =
new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this)
}

@Since("1.6.0")
object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] {

@Since("1.6.0")
override def read: MLReader[AFTSurvivalRegressionModel] = new AFTSurvivalRegressionModelReader

@Since("1.6.0")
override def load(path: String): AFTSurvivalRegressionModel = super.load(path)

/** [[MLWriter]] instance for [[AFTSurvivalRegressionModel]] */
private[AFTSurvivalRegressionModel] class AFTSurvivalRegressionModelWriter (
instance: AFTSurvivalRegressionModel
) extends MLWriter with Logging {

private case class Data(coefficients: Vector, intercept: Double, scale: Double)

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: coefficients, intercept, scale
val data = Data(instance.coefficients, instance.intercept, instance.scale)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class AFTSurvivalRegressionModelReader extends MLReader[AFTSurvivalRegressionModel] {

/** Checked against metadata when loading model */
private val className = classOf[AFTSurvivalRegressionModel].getName

override def load(path: String): AFTSurvivalRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath)
.select("coefficients", "intercept", "scale").head()
val coefficients = data.getAs[Vector](0)
val intercept = data.getDouble(1)
val scale = data.getDouble(2)
val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale)

DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
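
// For reference, a minimal round-trip sketch for the model writer/reader above,
// assuming a fitted model `model: AFTSurvivalRegressionModel` is already in scope;
// the path is hypothetical. saveImpl writes the Params metadata plus a one-row
// Parquet file holding (coefficients, intercept, scale) under <path>/data, and the
// reader rebuilds the model from those columns.
//
//   val modelPath = "/tmp/aft-model"                  // hypothetical location
//   model.write.overwrite().save(modelPath)
//   val loaded = AFTSurvivalRegressionModel.load(modelPath)
//   assert(loaded.coefficients == model.coefficients)
//   assert(loaded.intercept == model.intercept)
//   assert(loaded.scale == model.scale)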

/**
mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -17,18 +17,22 @@

package org.apache.spark.ml.regression

import org.apache.hadoop.fs.Path

import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter
import org.apache.spark.ml.util._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel}
import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel

/**
@@ -127,7 +131,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
@Since("1.5.0")
@Experimental
class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String)
extends Estimator[IsotonicRegressionModel] with IsotonicRegressionBase {
extends Estimator[IsotonicRegressionModel]
with IsotonicRegressionBase with DefaultParamsWritable {

@Since("1.5.0")
def this() = this(Identifiable.randomUID("isoReg"))
@@ -179,6 +184,13 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
}
}

@Since("1.6.0")
object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] {

@Since("1.6.0")
override def load(path: String): IsotonicRegression = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by IsotonicRegression.
@@ -194,7 +206,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
class IsotonicRegressionModel private[ml] (
override val uid: String,
private val oldModel: MLlibIsotonicRegressionModel)
extends Model[IsotonicRegressionModel] with IsotonicRegressionBase {
extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable {

/** @group setParam */
@Since("1.5.0")
@@ -240,4 +252,61 @@ class IsotonicRegressionModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false)
}

@Since("1.6.0")
override def write: MLWriter =
new IsotonicRegressionModelWriter(this)
}

@Since("1.6.0")
object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] {

@Since("1.6.0")
override def read: MLReader[IsotonicRegressionModel] = new IsotonicRegressionModelReader

@Since("1.6.0")
override def load(path: String): IsotonicRegressionModel = super.load(path)

/** [[MLWriter]] instance for [[IsotonicRegressionModel]] */
private[IsotonicRegressionModel] class IsotonicRegressionModelWriter (
instance: IsotonicRegressionModel
) extends MLWriter with Logging {

private case class Data(
boundaries: Array[Double],
predictions: Array[Double],
isotonic: Boolean)

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: boundaries, predictions, isotonic
val data = Data(
instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class IsotonicRegressionModelReader extends MLReader[IsotonicRegressionModel] {

/** Checked against metadata when loading model */
private val className = classOf[IsotonicRegressionModel].getName

override def load(path: String): IsotonicRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath)
.select("boundaries", "predictions", "isotonic").head()
val boundaries = data.getAs[Seq[Double]](0).toArray
val predictions = data.getAs[Seq[Double]](1).toArray
val isotonic = data.getBoolean(2)
val model = new IsotonicRegressionModel(
metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic))

DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
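
// Likewise, a minimal sketch of the isotonic model round trip implemented above,
// assuming a fitted `model: IsotonicRegressionModel` and a hypothetical path. The
// writer flattens the wrapped MLlib model into (boundaries, predictions, isotonic)
// and the reader reconstructs it from that single Parquet row.
//
//   val irPath = "/tmp/isotonic-model"                // hypothetical location
//   model.write.overwrite().save(irPath)
//   val loaded = IsotonicRegressionModel.load(irPath)
//   assert(loaded.boundaries == model.boundaries)     // ml boundaries/predictions are Vectors
//   assert(loaded.predictions == model.predictions)
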
mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -21,14 +21,15 @@ import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}

class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
class AFTSurvivalRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

@transient var datasetUnivariate: DataFrame = _
@transient var datasetMultivariate: DataFrame = _
@@ -332,4 +333,32 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
assert(prediction ~== model.predict(features) relTol 1E-5)
}
}

test("read/write") {
def checkModelData(
model: AFTSurvivalRegressionModel,
model2: AFTSurvivalRegressionModel): Unit = {
assert(model.intercept === model2.intercept)
assert(model.coefficients === model2.coefficients)
assert(model.scale === model2.scale)
}
val aft = new AFTSurvivalRegression()
testEstimatorAndModelReadWrite(aft, datasetMultivariate,
AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
}
}

object AFTSurvivalRegressionSuite {

/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allParamSettings: Map[String, Any] = Map(
"predictionCol" -> "myPrediction",
"fitIntercept" -> true,
"maxIter" -> 2,
"tol" -> 0.01
)
}
mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -19,12 +19,14 @@ package org.apache.spark.ml.regression

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}

class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
class IsotonicRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

private def generateIsotonicInput(labels: Seq[Double]): DataFrame = {
sqlContext.createDataFrame(
labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) }
@@ -164,4 +166,32 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {

assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
}

test("read/write") {
val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18))

def checkModelData(model: IsotonicRegressionModel, model2: IsotonicRegressionModel): Unit = {
assert(model.boundaries === model2.boundaries)
assert(model.predictions === model2.predictions)
assert(model.isotonic === model2.isotonic)
}

val ir = new IsotonicRegression()
testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
checkModelData)
}
}

object IsotonicRegressionSuite {

/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allParamSettings: Map[String, Any] = Map(
"predictionCol" -> "myPrediction",
"isotonic" -> true,
"featureIndex" -> 0
)
}