Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
updated based on comments
  • Loading branch information
imatiach-msft committed Jan 12, 2019
commit 52bb65b87453aeb643e34b32a54c1de247eac322
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object TestingUtils {
* Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue,
* the relative tolerance is meaningless, so the exception will be raised to warn users.
*/
private[ml] def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
// Special case for NaNs
if (x.isNaN && y.isNaN) {
return true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ abstract class Classifier[
* [0, numClasses).
*/
protected def validateLabel(label: Double, numClasses: Int): Unit = {
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
s" [0, $numClasses).")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.Vector
* @param weight The weight of this instance.
* @param features The vector of features for this data point.
*/
private[spark] case class Instance(label: Double, weight: Double, features: Vector)
private[ml] case class Instance(label: Double, weight: Double, features: Vector)

/**
* Case class that represents an instance of data point with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.util.Try
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.{Instance => NewInstance}
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
import org.apache.spark.mllib.regression.LabeledPoint
Expand Down Expand Up @@ -93,7 +93,7 @@ private class RandomForest (
*/
def run(input: RDD[LabeledPoint]): RandomForestModel = {
val instances = input.map { case LabeledPoint(label, features) =>
NewInstance(label, 1.0, features.asML)
Instance(label, 1.0, features.asML)
}
val trees: Array[NewDTModel] =
NewRandomForest.run(instances, strategy, numTrees, featureSubsetStrategy, seed.toLong, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,14 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
.setSeed(123)
MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeRegressionModel,
DecisionTreeRegressor](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, RelativeErrorComparison(_, _, 0.05), 0.9))
MLTestingUtils.modelPredictionEquals(df, _ ~== _ relTol 0.05, 0.9))
MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeRegressionModel,
DecisionTreeRegressor](df.as[LabeledPoint], estimator, numClasses,
MLTestingUtils.modelPredictionEquals(df, RelativeErrorComparison(_, _, 0.1), 0.8),
MLTestingUtils.modelPredictionEquals(df, _ ~== _ relTol 0.1, 0.8),
outlierRatio = 2)
MLTestingUtils.testOversamplingVsWeighting[DecisionTreeRegressionModel,
DecisionTreeRegressor](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, RelativeErrorComparison(_, _, 0.01), 1.0), seed)
MLTestingUtils.modelPredictionEquals(df, _ ~== _ relTol 0.01, 1.0), seed)
}
}

Expand Down