diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala index 0187ad603a65..154fa07fc825 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala @@ -15,6 +15,16 @@ * limitations under the License. */ +/* + * Usage: + * + * sbt package + * + * spark-submit --class "org.apache.spark.examples.mllib.NaiveBayesExample" + * --master local[4] + * examples/target/scala-2.11/spark-examples_2.11-2.0.0-SNAPSHOT.jar + */ + // scalastyle:off println package org.apache.spark.examples.mllib @@ -23,6 +33,9 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint + +import java.io.File +import org.apache.commons.io.FileUtils // $example off$ object NaiveBayesExample { @@ -46,10 +59,19 @@ object NaiveBayesExample { val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() + println("model accuracy %f".format(accuracy)) // Save and load model - model.save(sc, "target/tmp/myNaiveBayesModel") - val sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel") + val outputDir = "target/tmp/myNaiveBayesModel" + // Clear any previous output; deleteDirectory does nothing if the path is absent. + FileUtils.deleteDirectory(new File(outputDir)) + model.save(sc, outputDir) + val sameModel = NaiveBayesModel.load(sc, outputDir) + + val samePredictionAndLabel = test.map(p => (sameModel.predict(p.features), p.label)) + val sameAccuracy = 1.0 * samePredictionAndLabel.filter(x => x._1 == x._2).count() / + test.count() + println("sameModel accuracy %f".format(sameAccuracy)) // $example off$ } }