22 changes: 15 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -27,7 +27,7 @@ import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{SparkContext, Logging}
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.MLReader
import org.apache.spark.ml.util.MLWriter
@@ -174,27 +174,31 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
}

+@Since("1.6.0")
override def write: MLWriter = new Pipeline.PipelineWriter(this)
}

+@Since("1.6.0")
object Pipeline extends MLReadable[Pipeline] {

+@Since("1.6.0")
override def read: MLReader[Pipeline] = new PipelineReader

+@Since("1.6.0")
override def load(path: String): Pipeline = super.load(path)

-private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter {
+private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter {

SharedReadWrite.validateStages(instance.getStages)

override protected def saveImpl(path: String): Unit =
SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
}

-private[ml] class PipelineReader extends MLReader[Pipeline] {
+private class PipelineReader extends MLReader[Pipeline] {

/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.Pipeline"
private val className = classOf[Pipeline].getName

override def load(path: String): Pipeline = {
val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
@@ -333,29 +337,33 @@ class PipelineModel private[ml] (
new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
}

+@Since("1.6.0")
override def write: MLWriter = new PipelineModel.PipelineModelWriter(this)
}

+@Since("1.6.0")
object PipelineModel extends MLReadable[PipelineModel] {

import Pipeline.SharedReadWrite

+@Since("1.6.0")
override def read: MLReader[PipelineModel] = new PipelineModelReader

+@Since("1.6.0")
override def load(path: String): PipelineModel = super.load(path)

-private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {
+private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {

SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])

override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
}

-private[ml] class PipelineModelReader extends MLReader[PipelineModel] {
+private class PipelineModelReader extends MLReader[PipelineModel] {

/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.PipelineModel"
private val className = classOf[PipelineModel].getName

override def load(path: String): PipelineModel = {
val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
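A minimal usage sketch of the persistence API touched in this file. The stage, the paths, and the commented-out training DataFrame are hypothetical, and it assumes every stage in the pipeline supports ML persistence:

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.feature.Tokenizer

// Hypothetical single-stage pipeline; any stage works as long as it is itself ML-writable.
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val pipeline = new Pipeline().setStages(Array[PipelineStage](tokenizer))

// Pipeline.write is the PipelineWriter defined above; save/load round-trips the unfitted pipeline.
pipeline.write.save("/tmp/unfitted-pipeline")
val restoredPipeline = Pipeline.load("/tmp/unfitted-pipeline")

// A fitted PipelineModel persists its stages the same way (trainingDF is a placeholder).
// val model: PipelineModel = pipeline.fit(trainingDF)
// model.write.save("/tmp/fitted-pipeline-model")
// val restoredModel = PipelineModel.load("/tmp/fitted-pipeline-model")
```

On disk, SharedReadWrite.saveImpl writes the pipeline's own metadata plus one directory per stage under the given path, and SharedReadWrite.load reads them back in order.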
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
import org.apache.hadoop.fs.Path

import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -525,18 +525,23 @@ class LogisticRegressionModel private[ml] (
*
* This also does not save the [[parent]] currently.
*/
+@Since("1.6.0")
override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
}


+@Since("1.6.0")
object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {

+@Since("1.6.0")
override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader

+@Since("1.6.0")
override def load(path: String): LogisticRegressionModel = super.load(path)

/** [[MLWriter]] instance for [[LogisticRegressionModel]] */
-private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
+private[LogisticRegressionModel]
+class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
extends MLWriter with Logging {

private case class Data(
@@ -552,15 +557,15 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
instance.coefficients)
val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

-private[classification] class LogisticRegressionModelReader
+private class LogisticRegressionModelReader
extends MLReader[LogisticRegressionModel] {

/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
private val className = classOf[LogisticRegressionModel].getName

override def load(path: String): LogisticRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
@@ -603,7 +608,7 @@ private[classification] class MultiClassSummarizer extends Serializable {
* @return This MultilabelSummarizer
*/
def add(label: Double, weight: Double = 1.0): this.type = {
require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")

if (weight == 0.0) return this

@@ -839,7 +844,7 @@ private class LogisticAggregator(
instance match { case Instance(label, weight, features) =>
require(dim == features.size, s"Dimensions mismatch when adding new instance." +
s" Expecting $dim but got ${features.size}.")
-require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")

if (weight == 0.0) return this

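For reference, a sketch of the save/load round trip the LogisticRegressionModelWriter/Reader above enable. The data, paths, and parameter values are made up, and sqlContext is assumed to be in scope (as in spark-shell):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}

// Tiny made-up training set with the expected "label" and "features" columns.
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

val model = new LogisticRegression().setMaxIter(10).setRegParam(0.01).fit(training)

// saveImpl above writes the Params metadata plus a single-partition Parquet file
// containing numClasses, numFeatures, intercept, and coefficients under <path>/data.
model.write.save("/tmp/lr-model")

// load checks the stored class name (now classOf[LogisticRegressionModel].getName)
// against the saved metadata before rebuilding the model.
val restored = LogisticRegressionModel.load("/tmp/lr-model")
assert(restored.coefficients == model.coefficients)
```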
mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -266,7 +266,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {

private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] {

private val className = "org.apache.spark.ml.feature.CountVectorizerModel"
private val className = classOf[CountVectorizerModel].getName

override def load(path: String): CountVectorizerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -155,7 +155,7 @@ object IDFModel extends MLReadable[IDFModel] {

private class IDFModelReader extends MLReader[IDFModel] {

private val className = "org.apache.spark.ml.feature.IDFModel"
private val className = classOf[IDFModel].getName

override def load(path: String): IDFModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -210,7 +210,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {

private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] {

private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"
private val className = classOf[MinMaxScalerModel].getName

override def load(path: String): MinMaxScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -180,7 +180,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {

private class StandardScalerModelReader extends MLReader[StandardScalerModel] {

private val className = "org.apache.spark.ml.feature.StandardScalerModel"
private val className = classOf[StandardScalerModel].getName

override def load(path: String): StandardScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -210,7 +210,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {

private class StringIndexerModelReader extends MLReader[StringIndexerModel] {

private val className = "org.apache.spark.ml.feature.StringIndexerModel"
private val className = classOf[StringIndexerModel].getName

override def load(path: String): StringIndexerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -237,7 +237,7 @@ object ALSModel extends MLReadable[ALSModel] {
@Since("1.6.0")
override def load(path: String): ALSModel = super.load(path)

-private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter {
+private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
val extraMetadata = render("rank" -> instance.rank)
@@ -249,10 +249,10 @@ }
}
}

-private[recommendation] class ALSModelReader extends MLReader[ALSModel] {
+private class ALSModelReader extends MLReader[ALSModel] {

/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.recommendation.ALSModel"
private val className = classOf[ALSModel].getName

override def load(path: String): ALSModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
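A recurring change across these readers is replacing the hard-coded fully qualified class name with classOf[...].getName. A tiny illustrative check (the assertion is only for demonstration):

```scala
import org.apache.spark.ml.recommendation.ALSModel

// classOf[ALSModel].getName is resolved by the compiler, so a class or package rename
// can no longer silently diverge from the name checked against saved metadata at load time.
val className: String = classOf[ALSModel].getName
assert(className == "org.apache.spark.ml.recommendation.ALSModel")
```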
mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -467,14 +467,14 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
// Save model data: intercept, coefficients
val data = Data(instance.intercept, instance.coefficients)
val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] {

/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.regression.LinearRegressionModel"
private val className = classOf[LinearRegressionModel].getName

override def load(path: String): LinearRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
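The model writers in this patch share one layout: param metadata saved next to a data/ subdirectory holding a single small Parquet file with the model payload. A rough standalone sketch of that payload round trip; the Data fields, paths, and sqlContext handling are illustrative, and the metadata side handled by DefaultParamsWriter/Reader is omitted:

```scala
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SQLContext

// Illustrative payload, loosely mirroring LinearRegressionModelWriter's Data case class.
case class Data(intercept: Double, coefficients: Array[Double])

def saveData(sqlContext: SQLContext, path: String, data: Data): Unit = {
  val dataPath = new Path(path, "data").toString
  // repartition(1): the payload is a single row, so keep it as one Parquet file.
  sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}

def loadData(sqlContext: SQLContext, path: String): Data = {
  val dataPath = new Path(path, "data").toString
  val row = sqlContext.read.parquet(dataPath).select("intercept", "coefficients").head()
  Data(row.getDouble(0), row.getSeq[Double](1).toArray)
}
```

The real readers rebuild the model from these stored values and then restore its Params from the saved metadata.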
mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -898,7 +898,7 @@ object LogisticRegressionSuite {
"regParam" -> 0.01,
"elasticNetParam" -> 0.1,
"maxIter" -> 2, // intentionally small
"fitIntercept" -> false,
"fitIntercept" -> true,
"tol" -> 0.8,
"standardization" -> false,
"threshold" -> 0.6