Closed · Changes from 1 commit · 31 commits
ce73c63
added Bernoulli option to naive bayes model in mllib, added optional …
leahmcguire Jan 16, 2015
4a3676d
Updated changes re-comments. Got rid of verbose populateMatrix method…
leahmcguire Jan 21, 2015
0313c0c
fixed style error in NaiveBayes.scala
leahmcguire Jan 21, 2015
76e5b0f
removed unnecessary sort from test
leahmcguire Jan 26, 2015
d9477ed
removed old inaccurate comment from test suite for mllib naive bayes
leahmcguire Feb 26, 2015
3891bf2
synced with apache spark and resolved merge conflict
leahmcguire Feb 27, 2015
5a4a534
fixed scala style error in NaiveBayes
leahmcguire Feb 27, 2015
b61b5e2
added back compatible constructor to NaiveBayesModel to fix MIMA test…
leahmcguire Mar 2, 2015
3730572
modified NB model type to be more Java-friendly
jkbradley Mar 3, 2015
b93aaf6
Merge pull request #1 from jkbradley/nb-model-type
leahmcguire Mar 5, 2015
7622b0c
added comments and fixed style as per rb
leahmcguire Mar 5, 2015
dc65374
integrated model type fix
leahmcguire Mar 5, 2015
85f298f
Merge remote-tracking branch 'upstream/master'
leahmcguire Mar 5, 2015
e016569
updated test suite with model type fix
leahmcguire Mar 5, 2015
ea09b28
Merge remote-tracking branch 'upstream/master'
leahmcguire Mar 5, 2015
900b586
fixed model call so that uses type argument
leahmcguire Mar 5, 2015
b85b0c9
Merge remote-tracking branch 'upstream/master'
leahmcguire Mar 5, 2015
c298e78
fixed scala style errors
leahmcguire Mar 5, 2015
2d0c1ba
fixed typo in NaiveBayes
leahmcguire Mar 5, 2015
e2d925e
fixed nonserializable error that was causing naivebayes test failures
leahmcguire Mar 7, 2015
fb0a5c7
removed typo
leahmcguire Mar 9, 2015
01baad7
made fixes from code review
leahmcguire Mar 11, 2015
bea62af
put back in constructor for NaiveBayes
leahmcguire Mar 12, 2015
18f3219
removed private from naive bayes constructor for lambda only
leahmcguire Mar 12, 2015
a22d670
changed NaiveBayesModel modelType parameter back to NaiveBayes.ModelT…
leahmcguire Mar 17, 2015
852a727
merged with upstream master
leahmcguire Mar 21, 2015
6a8f383
Added new model save/load format 2.0 for NaiveBayesModel after modelT…
jkbradley Mar 22, 2015
9ad89ca
removed old code
jkbradley Mar 22, 2015
2224b15
Merge pull request #2 from jkbradley/leahmcguire-master
leahmcguire Mar 24, 2015
acb69af
removed enum type and replaces all modelType parameters with strings
leahmcguire Mar 28, 2015
f3c8994
changed checks on model type to requires
leahmcguire Mar 31, 2015
Added new model save/load format 2.0 for NaiveBayesModel after modelType parameter was added. Updated tests. Also updated ModelType enum-like type.
jkbradley committed Mar 22, 2015
commit 6a8f3830e25fdad24e280ae68fb83ba522d23c1b
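
For orientation, here is a minimal usage sketch of the new save/load path (not part of this diff; `sc`, the `training` RDD[LabeledPoint], and the output path are placeholder assumptions):

import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}

// Train a Bernoulli model, persist it in the new 2.0 layout, then read it back.
val model = NaiveBayes.train(training, lambda = 1.0, modelType = "bernoulli")
model.save(sc, "target/tmp/nbModel")    // writes version "2.0" metadata plus Parquet data that carries modelType
val sameModel = NaiveBayesModel.load(sc, "target/tmp/nbModel")  // load() dispatches on the stored format version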
File: NaiveBayes.scala (org.apache.spark.mllib.classification)
@@ -35,6 +35,8 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

import NaiveBayes.ModelType.{Bernoulli, Multinomial}


/**
* Model for Naive Bayes Classifiers.
@@ -54,7 +56,7 @@ class NaiveBayesModel private[mllib] (
extends ClassificationModel with Serializable with Saveable {

private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
this(labels, pi, theta, NaiveBayes.Multinomial)
this(labels, pi, theta, Multinomial)

/** A Java-friendly constructor that takes three Iterable parameters. */
private[mllib] def this(
@@ -70,10 +72,13 @@ class NaiveBayesModel private[mllib] (
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
// application of this condition (in predict function).
private val (brzNegTheta, brzNegThetaSum) = modelType match {
case NaiveBayes.Multinomial => (None, None)
case NaiveBayes.Bernoulli =>
case Multinomial => (None, None)
case Bernoulli =>
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
(Option(negTheta), Option(brzSum(negTheta, Axis._1)))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
}

override def predict(testData: RDD[Vector]): RDD[Double] = {
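
A brief derivation of why the hunk above precomputes log(1.0 - exp(theta)) (my summary, not text from the PR): in the Bernoulli model each \theta_{cj} stores the log probability that feature j is present in class c, so for a 0/1 feature vector x the class-conditional log-likelihood is

\log P(x \mid c) = \sum_j \bigl[ x_j \theta_{cj} + (1 - x_j)\log(1 - e^{\theta_{cj}}) \bigr]
                 = \sum_j x_j \bigl(\theta_{cj} - \log(1 - e^{\theta_{cj}})\bigr) + \sum_j \log(1 - e^{\theta_{cj}}),

which is exactly the (brzTheta - brzNegTheta) * x + brzNegThetaSum expression used in the Bernoulli branch of predict in the next hunk; the Multinomial branch only needs pi + theta * x, so nothing is precomputed for it.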
@@ -86,29 +91,32 @@ class NaiveBayesModel private[mllib] (

override def predict(testData: Vector): Double = {
modelType match {
case NaiveBayes.Multinomial =>
case Multinomial =>
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
case NaiveBayes.Bernoulli =>
case Bernoulli =>
labels (brzArgmax (brzPi +
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
}
}

override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType.toString)
NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
}

override protected def formatVersion: String = "1.0"
override protected def formatVersion: String = "2.0"
}

object NaiveBayesModel extends Loader[NaiveBayesModel] {

import org.apache.spark.mllib.util.Loader._

private object SaveLoadV1_0 {
private[mllib] object SaveLoadV2_0 {

def thisFormatVersion: String = "1.0"
def thisFormatVersion: String = "2.0"

/** Hard-code class name string in case it changes in the future */
def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
@@ -127,8 +135,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
// Create JSON metadata.
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length) ~
("modelType" -> data.modelType)))
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

// Create Parquet data.
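
To summarize the storage change in this hunk (my reading of the diff): format 2.0 keeps modelType as a fourth column of the Parquet data row (labels, pi, theta, modelType) and drops it from the JSON metadata, so the metadata written here reduces to roughly {"class": "org.apache.spark.mllib.classification.NaiveBayesModel", "version": "2.0", "numFeatures": ..., "numClasses": ...} (illustrative; field values elided). The SaveLoadV1_0 object re-added below reads the original three-column layout and falls back to the Multinomial default through the three-argument NaiveBayesModel constructor.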
@@ -151,36 +158,82 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
new NaiveBayesModel(labels, pi, theta, modelType)
}

}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
implicit val formats = DefaultFormats
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
private[mllib] object SaveLoadV1_0 {

def thisFormatVersion: String = "1.0"

/** Hard-code class name string in case it changes in the future */
def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"

/** Model data for model import/export */
case class Data(
labels: Array[Double],
pi: Array[Double],
theta: Array[Array[Double]])

def save(sc: SparkContext, path: String, data: Data): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Create JSON metadata.
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

// Create Parquet data.
val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
dataRDD.saveAsParquetFile(dataPath(path))
}

def load(sc: SparkContext, path: String): NaiveBayesModel = {
val sqlContext = new SQLContext(sc)
// Load Parquet data.
val dataRDD = sqlContext.parquetFile(dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
new NaiveBayesModel(labels, pi, theta)
}
}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
val classNameV2_0 = SaveLoadV2_0.thisClassName
val (model, numFeatures, numClasses) = (loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val model = SaveLoadV1_0.load(sc, path)
assert(model.pi.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class priors vector pi had ${model.pi.size} elements")
assert(model.theta.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class conditionals array theta had ${model.theta.size} elements")
assert(model.theta.forall(_.size == numFeatures),
s"NaiveBayesModel.load expected $numFeatures features," +
s" but class conditionals array theta had elements of size:" +
s" ${model.theta.map(_.size).mkString(",")}")
assert(model.modelType == getModelType(metadata))
model
(model, numFeatures, numClasses)
case (className, "2.0") if className == classNameV2_0 =>
val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val model = SaveLoadV2_0.load(sc, path)
(model, numFeatures, numClasses)
case _ => throw new Exception(
s"NaiveBayesModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
assert(model.pi.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class priors vector pi had ${model.pi.size} elements")
assert(model.theta.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class conditionals array theta had ${model.theta.size} elements")
assert(model.theta.forall(_.size == numFeatures),
s"NaiveBayesModel.load expected $numFeatures features," +
s" but class conditionals array theta had elements of size:" +
s" ${model.theta.map(_.size).mkString(",")}")
model
}
}

@@ -197,9 +250,9 @@ class NaiveBayes private (
private var lambda: Double,
private var modelType: NaiveBayes.ModelType) extends Serializable with Logging {
Review comment (Member): Add getModelType method

def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
def this(lambda: Double) = this(lambda, Multinomial)

def this() = this(1.0, NaiveBayes.Multinomial)
def this() = this(1.0, Multinomial)

/** Set the smoothing parameter. Default: 1.0. */
def setLambda(lambda: Double): NaiveBayes = {
@@ -210,9 +263,22 @@ class NaiveBayes private (
/** Get the smoothing parameter. */
def getLambda: Double = lambda

/** Set the model type. Default: Multinomial. */
def setModelType(model: NaiveBayes.ModelType): NaiveBayes = {
this.modelType = model
/**
* Set the model type using a string (case-insensitive).
* Supported options: "multinomial" and "bernoulli".
* (default: multinomial)
*/
def setModelType(modelType: String): NaiveBayes = {
setModelType(NaiveBayes.ModelType.fromString(modelType))
}

/**
* Set the model type.
* Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]]
* (default: Multinomial)
*/
def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = {
this.modelType = modelType
this
}
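
As a quick usage sketch of the two setters above (illustrative only; `training` is an assumed RDD[LabeledPoint]):

import org.apache.spark.mllib.classification.NaiveBayes

// Builder-pattern configuration, as suggested in the review discussion further down.
val nb = new NaiveBayes()
  .setLambda(1.0)
  .setModelType("bernoulli")   // case-insensitive string form
// Equivalently: .setModelType(NaiveBayes.ModelType.Bernoulli)
val model = nb.run(training)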

@@ -270,8 +336,11 @@ class NaiveBayes private (
labels(i) = label
pi(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = modelType match {
case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda)
case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case Bernoulli => math.log(n + 2.0 * lambda)
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
}
var j = 0
while (j < numFeatures) {
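
A note on the two denominators in the hunk above (my summary): the Multinomial estimate normalizes each smoothed feature count by the total count of all features in the class plus numFeatures * lambda, while the Bernoulli estimate normalizes the smoothed document-frequency count by the number of documents in the class plus 2.0 * lambda, i.e. add-lambda smoothing over the two outcomes "feature present" and "feature absent".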
@@ -317,7 +386,7 @@ object NaiveBayes {
* @param lambda The smoothing parameter
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda, NaiveBayes.Multinomial).run(input)
new NaiveBayes(lambda, NaiveBayes.ModelType.Multinomial).run(input)
}

/**
@@ -339,12 +408,45 @@ object NaiveBayes {
* multinomial or bernoulli
*/
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
Review comment (Member): I'm actually wondering if we even need to provide this static train method. Users can use the builder pattern new NaiveBayes().setLambda().setModelType().run() instead. Could you please remove it? Sorry for not noticing this earlier!

Review comment (Contributor, Author): If we remove this static train method should we also remove the static train method that just includes lambda (line 326). Otherwise the train calls are inconsistent for setting different model parameters (lambda and modelType).

Review comment (Member): I wouldn't call it inconsistent, but it does limit users who want static methods. Adding new train() methods for every new parameter ended up hurting us for trees & ensembles because the list became so long, so I'd still vote for removing this static train method. We could even deprecate the other static train method and tell users to use the builder pattern instead. If you feel strongly about it, though, I'm OK with leaving it as is (since there are not very many parameters for NB).

new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input)
new NaiveBayes(lambda, ModelType.fromString(modelType)).run(input)
}

/** Provides static methods for using ModelType. */
sealed abstract class ModelType extends Serializable

object ModelType extends Serializable {

/**
* Get the model type from a string.
* @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive)
*/
def fromString(modelType: String): ModelType = modelType.toLowerCase match {
case "multinomial" => Multinomial
case "bernoulli" => Bernoulli
case _ =>
throw new IllegalArgumentException(
s"NaiveBayes.ModelType.fromString did not recognize string: $modelType")
}

final val Multinomial: ModelType = {
case object Multinomial extends ModelType with Serializable {
override def toString: String = "multinomial"
}
Multinomial
}

final val Bernoulli: ModelType = {
case object Bernoulli extends ModelType with Serializable {
override def toString: String = "bernoulli"
}
Bernoulli
}
}

/** Java-friendly accessor for supported ModelType options */
final val modelTypes = ModelType

/*
object MODELTYPE extends Serializable{
final val MULTINOMIAL_STRING = "multinomial"
final val BERNOULLI_STRING = "bernoulli"
@@ -368,6 +470,6 @@ object NaiveBayes {
final val Bernoulli: ModelType = new ModelType {
Review comment (Member): Add doc, perhaps something like "Constant for specifying ModelType parameter: Bernoulli model"

override def toString: String = ModelType.BERNOULLI_STRING
}

*/
}
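
A short sketch of how the enum-like ModelType above behaves (illustrative, e.g. from the spark-shell):

import org.apache.spark.mllib.classification.NaiveBayes

NaiveBayes.ModelType.fromString("Multinomial")   // => Multinomial (matching is case-insensitive)
NaiveBayes.ModelType.fromString("bernoulli")     // => Bernoulli
NaiveBayes.ModelType.fromString("gaussian")      // throws IllegalArgumentException
// Java callers reach the same constants through the accessor, e.g. NaiveBayes.modelTypes().Bernoulli(),
// which is what the new Java test below exercises.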

File: JavaNaiveBayesSuite.java (org.apache.spark.mllib.classification)
@@ -17,20 +17,22 @@

package org.apache.spark.mllib.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

public class JavaNaiveBayesSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -102,4 +104,11 @@ public Vector call(LabeledPoint v) throws Exception {
// Should be able to get the first prediction.
predictions.first();
}

@Test
public void testModelTypeSetters() {
NaiveBayes nb = new NaiveBayes()
.setModelType(NaiveBayes.modelTypes().Bernoulli())
.setModelType(NaiveBayes.modelTypes().Multinomial());
}
}
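
The testModelTypeSetters test added above presumably exists to confirm that the ModelType constants stay reachable from Java through the NaiveBayes.modelTypes() accessor (plain Java cannot use the NaiveBayes.ModelType.Multinomial spelling for members of a nested Scala object), in line with the earlier "modified NB model type to be more Java-friendly" commit in this PR.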