-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-4894][mllib] Added Bernoulli option to NaiveBayes model in mllib #4087
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
ce73c63
4a3676d
0313c0c
76e5b0f
d9477ed
3891bf2
5a4a534
b61b5e2
3730572
b93aaf6
7622b0c
dc65374
85f298f
e016569
ea09b28
900b586
b85b0c9
c298e78
2d0c1ba
e2d925e
fb0a5c7
01baad7
bea62af
18f3219
a22d670
852a727
6a8f383
9ad89ca
2224b15
acb69af
f3c8994
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,15 +49,15 @@ class NaiveBayesModel private[mllib] ( | |
| val modelType: String) | ||
| extends ClassificationModel with Serializable with Saveable { | ||
|
|
||
| def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = | ||
| private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = | ||
| this(labels, pi, theta, NaiveBayes.Multinomial.toString) | ||
|
|
||
| private val brzPi = new BDV[Double](pi) | ||
| private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t | ||
|
|
||
| // Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0 | ||
| // this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application | ||
| // of this condition in predict function | ||
| // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. | ||
| // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra | ||
| // application of this condition (in predict function). | ||
| private val (brzNegTheta, brzNegThetaSum) = NaiveBayes.ModelType.fromString(modelType) match { | ||
| case NaiveBayes.Multinomial => (None, None) | ||
| case NaiveBayes.Bernoulli => | ||
|
|
@@ -186,8 +186,6 @@ class NaiveBayes private ( | |
| private var lambda: Double, | ||
| private var modelType: NaiveBayes.ModelType) extends Serializable with Logging { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add getModelType method |
||
|
|
||
| def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) | ||
|
|
||
| def this() = this(1.0, NaiveBayes.Multinomial) | ||
|
|
||
| /** Set the smoothing parameter. Default: 1.0. */ | ||
|
|
@@ -202,6 +200,7 @@ class NaiveBayes private ( | |
| this | ||
| } | ||
|
|
||
| def getModelType(): NaiveBayes.ModelType = this.modelType | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Getters normally don't have parentheses in Spark |
||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove extra space |
||
| /** | ||
| * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. | ||
|
|
@@ -301,10 +300,9 @@ object NaiveBayes { | |
| * @param lambda The smoothing parameter | ||
| */ | ||
| def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { | ||
| new NaiveBayes(lambda).run(input) | ||
| new NaiveBayes(lambda, NaiveBayes.Multinomial).run(input) | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. | ||
| * | ||
|
|
@@ -327,11 +325,7 @@ object NaiveBayes { | |
| new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input) | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Model types supported in Naive Bayes: | ||
| * multinomial and Bernoulli currently supported | ||
| */ | ||
| /** Provides static methods for using ModelType. */ | ||
| sealed abstract class ModelType | ||
|
|
||
| object MODELTYPE { | ||
|
|
@@ -348,10 +342,12 @@ object NaiveBayes { | |
|
|
||
| final val ModelType = MODELTYPE | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add doc, perhaps something like "Provides static methods for using ModelType" |
||
|
|
||
| /** Constant for specifying ModelType parameter: multinomial model */ | ||
| final val Multinomial: ModelType = new ModelType { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add doc, perhaps something like "Constant for specifying ModelType parameter: Multinomial model" |
||
| override def toString: String = ModelType.MULTINOMIAL_STRING | ||
| } | ||
|
|
||
| /** Constant for specifying ModelType parameter: bernoulli model */ | ||
| final val Bernoulli: ModelType = new ModelType { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add doc, perhaps something like "Constant for specifying ModelType parameter: Bernoulli model" |
||
| override def toString: String = ModelType.BERNOULLI_STRING | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be nice to expose this as the enum-like type instead of a String. Does that sound reasonable (since users use it when calling NaiveBayes anyways).
It would be good to avoid using "ModelType.fromString" in the predict() method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to change this from the enum like type to the string to fix the unit test failures. An actual enum worked but the substitute that you suggested was throwing an non-serializable error on all of the NaiveBayes tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, that may have been because I didn't make those types extend Serializable. Does that work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep that fixes it :P