-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-4894][mllib] Added Bernoulli option to NaiveBayes model in mllib #4087
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
ce73c63
4a3676d
0313c0c
76e5b0f
d9477ed
3891bf2
5a4a534
b61b5e2
3730572
b93aaf6
7622b0c
dc65374
85f298f
e016569
ea09b28
900b586
b85b0c9
c298e78
2d0c1ba
e2d925e
fb0a5c7
01baad7
bea62af
18f3219
a22d670
852a727
6a8f383
9ad89ca
2224b15
acb69af
f3c8994
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
…ype parameter was added. Updated tests. Also updated ModelType enum-like type.
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,8 @@ import org.apache.spark.mllib.util.{Loader, Saveable} | |
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.{DataFrame, SQLContext} | ||
|
|
||
| import NaiveBayes.ModelType.{Bernoulli, Multinomial} | ||
|
|
||
|
|
||
| /** | ||
| * Model for Naive Bayes Classifiers. | ||
|
|
@@ -54,7 +56,7 @@ class NaiveBayesModel private[mllib] ( | |
| extends ClassificationModel with Serializable with Saveable { | ||
|
|
||
| private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = | ||
| this(labels, pi, theta, NaiveBayes.Multinomial) | ||
| this(labels, pi, theta, Multinomial) | ||
|
|
||
| /** A Java-friendly constructor that takes three Iterable parameters. */ | ||
| private[mllib] def this( | ||
|
|
@@ -70,10 +72,13 @@ class NaiveBayesModel private[mllib] ( | |
| // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra | ||
| // application of this condition (in predict function). | ||
| private val (brzNegTheta, brzNegThetaSum) = modelType match { | ||
| case NaiveBayes.Multinomial => (None, None) | ||
| case NaiveBayes.Bernoulli => | ||
| case Multinomial => (None, None) | ||
| case Bernoulli => | ||
| val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) | ||
| (Option(negTheta), Option(brzSum(negTheta, Axis._1))) | ||
| case _ => | ||
| // This should never happen. | ||
| throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") | ||
| } | ||
|
|
||
| override def predict(testData: RDD[Vector]): RDD[Double] = { | ||
|
|
@@ -86,29 +91,32 @@ class NaiveBayesModel private[mllib] ( | |
|
|
||
| override def predict(testData: Vector): Double = { | ||
| modelType match { | ||
| case NaiveBayes.Multinomial => | ||
| case Multinomial => | ||
| labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) | ||
| case NaiveBayes.Bernoulli => | ||
| case Bernoulli => | ||
| labels (brzArgmax (brzPi + | ||
| (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) | ||
| case _ => | ||
| // This should never happen. | ||
| throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") | ||
| } | ||
| } | ||
|
|
||
| override def save(sc: SparkContext, path: String): Unit = { | ||
| val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString) | ||
| NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) | ||
| val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType.toString) | ||
| NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) | ||
| } | ||
|
|
||
| override protected def formatVersion: String = "1.0" | ||
| override protected def formatVersion: String = "2.0" | ||
| } | ||
|
|
||
| object NaiveBayesModel extends Loader[NaiveBayesModel] { | ||
|
|
||
| import org.apache.spark.mllib.util.Loader._ | ||
|
|
||
| private object SaveLoadV1_0 { | ||
| private[mllib] object SaveLoadV2_0 { | ||
|
|
||
| def thisFormatVersion: String = "1.0" | ||
| def thisFormatVersion: String = "2.0" | ||
|
|
||
| /** Hard-code class name string in case it changes in the future */ | ||
| def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" | ||
|
|
@@ -127,8 +135,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { | |
| // Create JSON metadata. | ||
| val metadata = compact(render( | ||
| ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ | ||
| ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length) ~ | ||
| ("modelType" -> data.modelType))) | ||
| ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) | ||
| sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) | ||
|
|
||
| // Create Parquet data. | ||
|
|
@@ -151,36 +158,82 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { | |
| val modelType = NaiveBayes.ModelType.fromString(data.getString(3)) | ||
| new NaiveBayesModel(labels, pi, theta, modelType) | ||
| } | ||
|
|
||
| } | ||
|
|
||
| override def load(sc: SparkContext, path: String): NaiveBayesModel = { | ||
| def getModelType(metadata: JValue): NaiveBayes.ModelType = { | ||
| implicit val formats = DefaultFormats | ||
| NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String]) | ||
| private[mllib] object SaveLoadV1_0 { | ||
|
|
||
| def thisFormatVersion: String = "1.0" | ||
|
|
||
| /** Hard-code class name string in case it changes in the future */ | ||
| def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" | ||
|
|
||
| /** Model data for model import/export */ | ||
| case class Data( | ||
| labels: Array[Double], | ||
| pi: Array[Double], | ||
| theta: Array[Array[Double]]) | ||
|
|
||
| def save(sc: SparkContext, path: String, data: Data): Unit = { | ||
| val sqlContext = new SQLContext(sc) | ||
| import sqlContext.implicits._ | ||
|
|
||
| // Create JSON metadata. | ||
| val metadata = compact(render( | ||
| ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ | ||
| ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) | ||
| sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) | ||
|
|
||
| // Create Parquet data. | ||
| val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() | ||
| dataRDD.saveAsParquetFile(dataPath(path)) | ||
| } | ||
|
|
||
| def load(sc: SparkContext, path: String): NaiveBayesModel = { | ||
| val sqlContext = new SQLContext(sc) | ||
| // Load Parquet data. | ||
| val dataRDD = sqlContext.parquetFile(dataPath(path)) | ||
| // Check schema explicitly since erasure makes it hard to use match-case for checking. | ||
| checkSchema[Data](dataRDD.schema) | ||
| val dataArray = dataRDD.select("labels", "pi", "theta").take(1) | ||
| assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") | ||
| val data = dataArray(0) | ||
| val labels = data.getAs[Seq[Double]](0).toArray | ||
| val pi = data.getAs[Seq[Double]](1).toArray | ||
| val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray | ||
| new NaiveBayesModel(labels, pi, theta) | ||
| } | ||
| } | ||
|
|
||
| override def load(sc: SparkContext, path: String): NaiveBayesModel = { | ||
| val (loadedClassName, version, metadata) = loadMetadata(sc, path) | ||
| val classNameV1_0 = SaveLoadV1_0.thisClassName | ||
| (loadedClassName, version) match { | ||
| val classNameV2_0 = SaveLoadV2_0.thisClassName | ||
| val (model, numFeatures, numClasses) = (loadedClassName, version) match { | ||
| case (className, "1.0") if className == classNameV1_0 => | ||
| val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) | ||
| val model = SaveLoadV1_0.load(sc, path) | ||
| assert(model.pi.size == numClasses, | ||
| s"NaiveBayesModel.load expected $numClasses classes," + | ||
| s" but class priors vector pi had ${model.pi.size} elements") | ||
| assert(model.theta.size == numClasses, | ||
| s"NaiveBayesModel.load expected $numClasses classes," + | ||
| s" but class conditionals array theta had ${model.theta.size} elements") | ||
| assert(model.theta.forall(_.size == numFeatures), | ||
| s"NaiveBayesModel.load expected $numFeatures features," + | ||
| s" but class conditionals array theta had elements of size:" + | ||
| s" ${model.theta.map(_.size).mkString(",")}") | ||
| assert(model.modelType == getModelType(metadata)) | ||
| model | ||
| (model, numFeatures, numClasses) | ||
| case (className, "2.0") if className == classNameV2_0 => | ||
| val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) | ||
| val model = SaveLoadV2_0.load(sc, path) | ||
| (model, numFeatures, numClasses) | ||
| case _ => throw new Exception( | ||
| s"NaiveBayesModel.load did not recognize model with (className, format version):" + | ||
| s"($loadedClassName, $version). Supported:\n" + | ||
| s" ($classNameV1_0, 1.0)") | ||
| } | ||
| assert(model.pi.size == numClasses, | ||
| s"NaiveBayesModel.load expected $numClasses classes," + | ||
| s" but class priors vector pi had ${model.pi.size} elements") | ||
| assert(model.theta.size == numClasses, | ||
| s"NaiveBayesModel.load expected $numClasses classes," + | ||
| s" but class conditionals array theta had ${model.theta.size} elements") | ||
| assert(model.theta.forall(_.size == numFeatures), | ||
| s"NaiveBayesModel.load expected $numFeatures features," + | ||
| s" but class conditionals array theta had elements of size:" + | ||
| s" ${model.theta.map(_.size).mkString(",")}") | ||
| model | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -197,9 +250,9 @@ class NaiveBayes private ( | |
| private var lambda: Double, | ||
| private var modelType: NaiveBayes.ModelType) extends Serializable with Logging { | ||
|
|
||
| def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) | ||
| def this(lambda: Double) = this(lambda, Multinomial) | ||
|
|
||
| def this() = this(1.0, NaiveBayes.Multinomial) | ||
| def this() = this(1.0, Multinomial) | ||
|
|
||
| /** Set the smoothing parameter. Default: 1.0. */ | ||
| def setLambda(lambda: Double): NaiveBayes = { | ||
|
|
@@ -210,9 +263,22 @@ class NaiveBayes private ( | |
| /** Get the smoothing parameter. */ | ||
| def getLambda: Double = lambda | ||
|
|
||
| /** Set the model type. Default: Multinomial. */ | ||
| def setModelType(model: NaiveBayes.ModelType): NaiveBayes = { | ||
| this.modelType = model | ||
| /** | ||
| * Set the model type using a string (case-insensitive). | ||
| * Supported options: "multinomial" and "bernoulli". | ||
| * (default: multinomial) | ||
| */ | ||
| def setModelType(modelType: String): NaiveBayes = { | ||
| setModelType(NaiveBayes.ModelType.fromString(modelType)) | ||
| } | ||
|
|
||
| /** | ||
| * Set the model type. | ||
| * Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]] | ||
| * (default: Multinomial) | ||
| */ | ||
| def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = { | ||
| this.modelType = modelType | ||
| this | ||
| } | ||
|
|
||
|
|
@@ -270,8 +336,11 @@ class NaiveBayes private ( | |
| labels(i) = label | ||
| pi(i) = math.log(n + lambda) - piLogDenom | ||
| val thetaLogDenom = modelType match { | ||
| case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) | ||
| case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda) | ||
| case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) | ||
| case Bernoulli => math.log(n + 2.0 * lambda) | ||
| case _ => | ||
| // This should never happen. | ||
| throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") | ||
| } | ||
| var j = 0 | ||
| while (j < numFeatures) { | ||
|
|
@@ -317,7 +386,7 @@ object NaiveBayes { | |
| * @param lambda The smoothing parameter | ||
| */ | ||
| def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { | ||
| new NaiveBayes(lambda, NaiveBayes.Multinomial).run(input) | ||
| new NaiveBayes(lambda, NaiveBayes.ModelType.Multinomial).run(input) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -339,12 +408,45 @@ object NaiveBayes { | |
| * multinomial or bernoulli | ||
| */ | ||
| def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm actually wondering if we even need to provide this static train method. Users can use the builder pattern new NaiveBayes().setLambda().setModelType().run() instead. Could you please remove it? Sorry for not noticing this earlier!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we remove this static train method should we also remove the static train method that just includes lambda (line 326). Otherwise the train calls are inconsistent for setting different model parameters (lambda and modelType).
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wouldn't call it inconsistent, but it does limit users who want static methods. Adding new train() methods for every new parameter ended up hurting us for trees & ensembles because the list became so long, so I'd still vote for removing this static train method. We could even deprecate the other static train method and tell users to use the builder pattern instead. If you feel strongly about it, though, I'm OK with leaving it as is (since there are not very many parameters for NB). |
||
| new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input) | ||
| new NaiveBayes(lambda, ModelType.fromString(modelType)).run(input) | ||
| } | ||
|
|
||
| /** Provides static methods for using ModelType. */ | ||
| sealed abstract class ModelType extends Serializable | ||
|
|
||
| object ModelType extends Serializable { | ||
|
|
||
| /** | ||
| * Get the model type from a string. | ||
| * @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive) | ||
| */ | ||
| def fromString(modelType: String): ModelType = modelType.toLowerCase match { | ||
| case "multinomial" => Multinomial | ||
| case "bernoulli" => Bernoulli | ||
| case _ => | ||
| throw new IllegalArgumentException( | ||
| s"NaiveBayes.ModelType.fromString did not recognize string: $modelType") | ||
| } | ||
|
|
||
| final val Multinomial: ModelType = { | ||
| case object Multinomial extends ModelType with Serializable { | ||
| override def toString: String = "multinomial" | ||
| } | ||
| Multinomial | ||
| } | ||
|
|
||
| final val Bernoulli: ModelType = { | ||
| case object Bernoulli extends ModelType with Serializable { | ||
| override def toString: String = "bernoulli" | ||
| } | ||
| Bernoulli | ||
| } | ||
| } | ||
|
|
||
| /** Java-friendly accessor for supported ModelType options */ | ||
| final val modelTypes = ModelType | ||
|
|
||
| /* | ||
| object MODELTYPE extends Serializable{ | ||
| final val MULTINOMIAL_STRING = "multinomial" | ||
| final val BERNOULLI_STRING = "bernoulli" | ||
|
|
@@ -368,6 +470,6 @@ object NaiveBayes { | |
| final val Bernoulli: ModelType = new ModelType { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add doc, perhaps something like "Constant for specifying ModelType parameter: Bernoulli model" |
||
| override def toString: String = ModelType.BERNOULLI_STRING | ||
| } | ||
|
|
||
| */ | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add getModelType method