mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -17,12 +17,15 @@

package org.apache.spark.ml.classification

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
@Experimental
class NaiveBayes(override val uid: String)
extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams {
with NaiveBayesParams with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("nb"))

@@ -102,6 +105,13 @@ class NaiveBayes(override val uid: String)
override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
}

@Since("1.6.0")
object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {

@Since("1.6.0")
override def load(path: String): NaiveBayes = super.load(path)
}
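
With DefaultParamsWritable mixed into the estimator and this DefaultParamsReadable companion, a configured NaiveBayes can be persisted and restored with its Params intact. A minimal usage sketch, not taken from this patch; the path and the smoothing value are placeholders:

  // Configure, save, and reload the estimator; the restored copy keeps the same Params.
  val nb = new NaiveBayes().setSmoothing(0.1)
  nb.write.overwrite().save("/tmp/nb-estimator")   // placeholder path
  val restored = NaiveBayes.load("/tmp/nb-estimator")
  assert(restored.getSmoothing == nb.getSmoothing)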

/**
* :: Experimental ::
* Model produced by [[NaiveBayes]]
@@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] (
override val uid: String,
val pi: Vector,
val theta: Matrix)
extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
with NaiveBayesParams with MLWritable {

import OldNaiveBayes.{Bernoulli, Multinomial}

@@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] (
s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
}

@Since("1.6.0")
override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this)
}

private[ml] object NaiveBayesModel {
@Since("1.6.0")
object NaiveBayesModel extends MLReadable[NaiveBayesModel] {

/** Convert a model from the old API */
def fromOld(
private[ml] def fromOld(
oldModel: OldNaiveBayesModel,
parent: NaiveBayes): NaiveBayesModel = {
val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
@@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel {
oldModel.theta.flatten, true)
new NaiveBayesModel(uid, pi, theta)
}

@Since("1.6.0")
override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader

@Since("1.6.0")
override def load(path: String): NaiveBayesModel = super.load(path)

/** [[MLWriter]] instance for [[NaiveBayesModel]] */
private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {

private case class Data(pi: Vector, theta: Matrix)

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: pi, theta
val data = Data(instance.pi, instance.theta)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] {

/** Checked against metadata when loading model */
private val className = classOf[NaiveBayesModel].getName

override def load(path: String): NaiveBayesModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head()
val pi = data.getAs[Vector](0)
val theta = data.getAs[Matrix](1)
val model = new NaiveBayesModel(metadata.uid, pi, theta)

DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
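
The writer/reader pair above follows the layout established by DefaultParamsWriter and DefaultParamsReader: Params metadata is written as JSON under a metadata/ subdirectory and the single-row pi/theta DataFrame as Parquet under data/, which is exactly what NaiveBayesModelReader.load reads back. A round-trip sketch, assuming a fitted model named `model` and a scratch path:

  // Save the fitted model, reload it, and confirm the parameters survive the trip.
  model.write.overwrite().save("/tmp/nb-model")            // placeholder path
  val reloaded = NaiveBayesModel.load("/tmp/nb-model")
  assert(reloaded.pi.toArray.sameElements(model.pi.toArray))
  assert(reloaded.theta.toArray.sameElements(model.theta.toArray))
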
mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala (67 changes: 61 additions & 6 deletions)
@@ -17,18 +17,19 @@

package org.apache.spark.ml.clustering

import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap}
import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
import org.apache.spark.ml.util._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.{DataFrame, Row}


/**
* Common params for KMeans and KMeansModel
*/
@@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Experimental
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
private val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams with MLWritable {

@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
@@ -129,6 +131,52 @@ class KMeansModel private[ml] (
val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
parentModel.computeCost(data)
}

@Since("1.6.0")
override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
}

@Since("1.6.0")
object KMeansModel extends MLReadable[KMeansModel] {

@Since("1.6.0")
override def read: MLReader[KMeansModel] = new KMeansModelReader

@Since("1.6.0")
override def load(path: String): KMeansModel = super.load(path)

/** [[MLWriter]] instance for [[KMeansModel]] */
private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {

private case class Data(clusterCenters: Array[Vector])

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: cluster centers
val data = Data(instance.clusterCenters)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class KMeansModelReader extends MLReader[KMeansModel] {

/** Checked against metadata when loading model */
private val className = classOf[KMeansModel].getName

override def load(path: String): KMeansModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
val clusterCenters = data.getAs[Seq[Vector]](0).toArray
val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))

DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
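
The same pattern as NaiveBayesModel applies here; the one wrinkle is that the cluster centers are written as a single Array[Vector] column and come back from Parquet as a Seq[Vector], hence the toArray before rebuilding the wrapped MLlibKMeansModel. A brief sketch, assuming a fitted `model` and a placeholder path:

  model.write.overwrite().save("/tmp/kmeans-model")        // placeholder path
  val reloaded = KMeansModel.load("/tmp/kmeans-model")
  assert(reloaded.clusterCenters.sameElements(model.clusterCenters))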

/**
@@ -141,7 +189,7 @@ class KMeansModel private[ml] (
@Experimental
class KMeans @Since("1.5.0") (
@Since("1.5.0") override val uid: String)
extends Estimator[KMeansModel] with KMeansParams {
extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {

setDefault(
k -> 2,
@@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") (
}
}

@Since("1.6.0")
object KMeans extends DefaultParamsReadable[KMeans] {

@Since("1.6.0")
override def load(path: String): KMeans = super.load(path)
}
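
Together with KMeansModel above, both the configured estimator and the fitted model can now be persisted. A usage sketch under placeholder paths and a DataFrame `dataset` with a features column; the settings are illustrative:

  val kmeans = new KMeans().setK(3).setSeed(1L)
  val model = kmeans.fit(dataset)
  kmeans.write.overwrite().save("/tmp/kmeans-estimator")
  model.write.overwrite().save("/tmp/kmeans-fitted")
  val sameEstimator = KMeans.load("/tmp/kmeans-estimator")
  val sameModel = KMeansModel.load("/tmp/kmeans-fitted")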

mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -21,15 +21,30 @@ import breeze.linalg.{Vector => BV}

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial}
import org.apache.spark.mllib.classification.NaiveBayesSuite._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.classification.NaiveBayesSuite._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.{DataFrame, Row}

class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

@transient var dataset: DataFrame = _

override def beforeAll(): Unit = {
super.beforeAll()

val pi = Array(0.5, 0.1, 0.4).map(math.log)
val theta = Array(
Array(0.70, 0.10, 0.10, 0.10), // label 0
Array(0.10, 0.70, 0.10, 0.10), // label 1
Array(0.10, 0.10, 0.70, 0.10) // label 2
).map(_.map(math.log))

class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
}

def validatePrediction(predictionAndLabels: DataFrame): Unit = {
val numOfErrorPredictions = predictionAndLabels.collect().count {
@@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
.select("features", "probability")
validateProbabilities(featureAndProbabilities, model, "bernoulli")
}

test("read/write") {
def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
assert(model.pi === model2.pi)
assert(model.theta === model2.theta)
}
val nb = new NaiveBayes()
testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
}
}

object NaiveBayesSuite {

/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allParamSettings: Map[String, Any] = Map(
"predictionCol" -> "myPrediction",
"smoothing" -> 0.1
)
}
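
The keys in allParamSettings are Param names; the shared helper in DefaultReadWriteTest presumably applies them before fitting so that non-default values are exercised by the save/load round trip. Spelled out with the public setters, the map above corresponds to:

  // Explicit equivalent of allParamSettings for NaiveBayes (same values as above).
  val nb = new NaiveBayes()
    .setPredictionCol("myPrediction")
    .setSmoothing(0.1)
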
mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -18,23 +18,15 @@
package org.apache.spark.ml.clustering

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}

private[clustering] case class TestRow(features: Vector)

object KMeansSuite {
def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
val sc = sql.sparkContext
val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
.map(v => new TestRow(v))
sql.createDataFrame(rdd)
}
}

class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

final val k = 5
@transient var dataset: DataFrame = _
@@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
}

test("read/write") {
def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
assert(model.clusterCenters === model2.clusterCenters)
}
val kmeans = new KMeans()
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
}
}

object KMeansSuite {
def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
val sc = sql.sparkContext
val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
.map(v => new TestRow(v))
sql.createDataFrame(rdd)
}

/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allParamSettings: Map[String, Any] = Map(
"predictionCol" -> "myPrediction",
"k" -> 3,
"maxIter" -> 2,
"tol" -> 0.01
)
}
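
Likewise for KMeans: generateKMeansData builds a small synthetic DataFrame of TestRow(features), and the map above corresponds to the explicit setters below. The sizes here are illustrative, not the ones the suite uses:

  val data = KMeansSuite.generateKMeansData(sqlContext, rows = 50, dim = 3, k = 5)
  val kmeans = new KMeans()
    .setPredictionCol("myPrediction")
    .setK(3)
    .setMaxIter(2)
    .setTol(0.01)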