Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
simple testSuites and run properly
  • Loading branch information
Yuewei Na authored and Yuewei Na committed Jun 7, 2016
commit 7bcabdac3d54ed9c682de5493da89f26e8a8e55a
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ class Strategy @Since("1.3.0") (
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
@Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array()) extends Serializable {
@Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array(1, 1))
extends Serializable {

/**
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ object WeightedGini extends Impurity {
* in order to compute impurity from a sample.
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param numClasses Number of classes for label.
* @param weights Weights of classes
* @param classWeights Weights of classes
*/
private[spark] class WeightedGiniAggregator(numClasses: Int, weights: Array[Double])
private[spark] class WeightedGiniAggregator(numClasses: Int, classWeights: Array[Double])
extends ImpurityAggregator(numClasses) with Serializable {

/**
Expand All @@ -108,7 +108,7 @@ private[spark] class WeightedGiniAggregator(numClasses: Int, weights: Array[Doub
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): WeightedGiniCalculator = {
new WeightedGiniCalculator(allStats.view(offset, offset + statsSize).toArray, weights)
new WeightedGiniCalculator(allStats.view(offset, offset + statsSize).toArray, classWeights)
}
}

Expand All @@ -117,16 +117,16 @@ private[spark] class WeightedGiniAggregator(numClasses: Int, weights: Array[Doub
* Unlike [[WeightedGiniAggregator]], this class stores its own data and is for a specific
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
* @param weights Weights of classes
* @param classWeights Weights of classes
*/
private[spark] class WeightedGiniCalculator(stats: Array[Double], weights: Array[Double])
private[spark] class WeightedGiniCalculator(stats: Array[Double], classWeights: Array[Double])
extends ImpurityCalculator(stats) {

var weightedStats = stats.zip(weights).map(x => x._1 * x._2)
var weightedStats = stats.zip(classWeights).map(x => x._1 * x._2)
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
def copy: WeightedGiniCalculator = new WeightedGiniCalculator(stats.clone(), weights.clone())
def copy: WeightedGiniCalculator = new WeightedGiniCalculator(stats.clone(), classWeights.clone())

/**
* Calculate the impurity from the stored sufficient statistics.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class DecisionTreeClassifierSuite
val numClasses = 2
compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)
}

/*
test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") {
val dt = new DecisionTreeClassifier()
.setMaxDepth(3)
Expand All @@ -91,7 +91,7 @@ class DecisionTreeClassifierSuite
compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
}
}
}
}*/

test("Multiclass classification stump with 3-ary (unordered) categorical features") {
val rdd = categoricalDataPointsForMulticlassRDD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(6), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
0, 0, 0.0, 0, 0, Array()
)
val featureSamples = Array.fill(200000)(math.random)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
Expand All @@ -110,7 +110,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(5), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
0, 0, 0.0, 0, 0, Array()
)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
Expand All @@ -124,7 +124,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
0, 0, 0.0, 0, 0, Array()
)
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
Expand All @@ -138,7 +138,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
0, 0, 0.0, 0, 0, Array()
)
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
Expand Down