Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix RegressionMetrics tests
  • Loading branch information
Feynman Liang committed Jul 13, 2015
commit c235de066583f2f0980de040b2aa879b5072a05a
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,20 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
val predictionAndObservations = sc.parallelize(
Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservations)
assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
assert(metrics.explainedVariance ~== 2.05729 absTol 1E-5,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
"root mean squared error mismatch")
assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch")
assert(metrics.r2 ~== 0.84582 absTol 1E-5, "r2 score mismatch")
}

test("regression metrics with complete fitting") {
val predictionAndObservations = sc.parallelize(
Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservations)
assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
assert(metrics.explainedVariance ~== 2.89583 absTol 1E-5,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
Expand Down