@@ -46,19 +46,19 @@ class NaiveBayesModel private[mllib] (
4646 val labels : Array [Double ],
4747 val pi : Array [Double ],
4848 val theta : Array [Array [Double ]],
49- val modelType : NaiveBayes . ModelType )
49+ val modelType : String )
5050 extends ClassificationModel with Serializable with Saveable {
5151
5252 def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
53- this (labels, pi, theta, NaiveBayes .Multinomial )
53+ this (labels, pi, theta, NaiveBayes .Multinomial .toString )
5454
5555 private val brzPi = new BDV [Double ](pi)
5656 private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
5757
5858 // Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
5959 // this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application
6060 // of this condition in predict function
61- private val (brzNegTheta, brzNegThetaSum) = modelType match {
61+ private val (brzNegTheta, brzNegThetaSum) = NaiveBayes . ModelType .fromString( modelType) match {
6262 case NaiveBayes .Multinomial => (None , None )
6363 case NaiveBayes .Bernoulli =>
6464 val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
@@ -74,7 +74,7 @@ class NaiveBayesModel private[mllib] (
7474 }
7575
7676 override def predict (testData : Vector ): Double = {
77- modelType match {
77+ NaiveBayes . ModelType .fromString( modelType) match {
7878 case NaiveBayes .Multinomial =>
7979 labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
8080 case NaiveBayes .Bernoulli =>
@@ -84,7 +84,7 @@ class NaiveBayesModel private[mllib] (
8484 }
8585
8686 override def save (sc : SparkContext , path : String ): Unit = {
87- val data = NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType.toString )
87+ val data = NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType)
8888 NaiveBayesModel .SaveLoadV1_0 .save(sc, path, data)
8989 }
9090
@@ -137,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
137137 val labels = data.getAs[Seq [Double ]](0 ).toArray
138138 val pi = data.getAs[Seq [Double ]](1 ).toArray
139139 val theta = data.getAs[Seq [Seq [Double ]]](2 ).map(_.toArray).toArray
140- val modelType = NaiveBayes .ModelType .fromString(data.getString(3 ))
140+ val modelType = NaiveBayes .ModelType .fromString(data.getString(3 )).toString
141141 new NaiveBayesModel (labels, pi, theta, modelType)
142142 }
143143 }
144144
145145 override def load (sc : SparkContext , path : String ): NaiveBayesModel = {
146- def getModelType (metadata : JValue ): NaiveBayes . ModelType = {
146+ def getModelType (metadata : JValue ): String = {
147147 implicit val formats = DefaultFormats
148- NaiveBayes .ModelType .fromString((metadata \ " modelType" ).extract[String ])
148+ NaiveBayes .ModelType .fromString((metadata \ " modelType" ).extract[String ]).toString
149149 }
150150 val (loadedClassName, version, metadata) = loadMetadata(sc, path)
151151 val classNameV1_0 = SaveLoadV1_0 .thisClassName
@@ -265,7 +265,7 @@ class NaiveBayes private (
265265 i += 1
266266 }
267267
268- new NaiveBayesModel (labels, pi, theta, modelType)
268+ new NaiveBayesModel (labels, pi, theta, modelType.toString )
269269 }
270270}
271271
0 commit comments