Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
added unit test
  • Loading branch information
manishamde committed Apr 30, 2014
commit 718506b2a0146a5794261a553847d363b7dfb932
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ object DecisionTreeRunner {
algo: Algo = Classification,
maxDepth: Int = 5,
impurity: ImpurityType = Gini,
maxBins: Int = 20)
maxBins: Int = 100)

def main(args: Array[String]) {
val defaultParams = Params()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.Filter
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.model.Split
import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vectors
Expand Down Expand Up @@ -390,6 +391,53 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 1)
}

// Verifies that finding the best splits for the second tree level gives the same result
// whether the level is processed as a single group or is forced to be split into groups.
test("test second level node building with/without groups") {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
// NOTE(review): the positional arguments 3 and 100 are presumably maxDepth and maxBins —
// confirm against the Strategy constructor.
val strategy = new Strategy(Classification, Entropy, 3, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
// The generated points carry two feature values each; presumably the outer arrays are
// indexed by feature.
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
// NOTE(review): the next two assertions repeat the index-0 checks above verbatim; index 1
// may have been intended — confirm.
assert(splits(0).length === 99)
assert(bins(0).length === 100)

// Both filters are built from an identical Split(0, 400, Continuous, List()) and differ only
// in the last Filter argument (-1 vs 1); presumably that selects the left / right side of the
// split — confirm against Filter.
val leftFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put a space after ,.

val rightFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),1)
// Three filter lists: an empty one, one holding leftFilter and one holding rightFilter.
val filters = Array[List[Filter]](List(),List(leftFilter),List(rightFilter))
// One impurity value per filter list above.
val parentImpurities = Array(0.5, 0.5, 0.5)

// Single group second level tree construction.
// The final argument (10 here, 0 in the grouped call below) is maxLevelForSingleGroup.
val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters,
splits, bins, 10)
assert(bestSplits.length === 2)
assert(bestSplits(0)._2.gain > 0)
assert(bestSplits(1)._2.gain > 0)

// maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
// level tree construction.
val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1,
filters, splits, bins, 0)
assert(bestSplitsWithGroups.length === 2)
assert(bestSplitsWithGroups(0)._2.gain > 0)
assert(bestSplitsWithGroups(1)._2.gain > 0)

// Verify whether the splits obtained using single group and multiple group level
// construction strategies are the same.
for (i <- 0 until bestSplits.length) {
// _1 is compared directly; the statistics in _2 are compared field by field.
assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1)
assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain)
assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
}

}

}

object DecisionTreeSuite {
Expand All @@ -412,6 +460,20 @@ object DecisionTreeSuite {
arr
}

// Builds a fixed array of 1000 labeled points ordered by index i (0 until 1000).
// Point i has the dense feature vector (i, 1000 - i); its label is 0.0 when i < 600 and
// 1.0 otherwise, so the label changes exactly once along the ordering.
def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put a space between ) and {.

// First 600 points (i = 0..599) get label 0.0.
if (i < 600){
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

){ -> ) {

val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
arr(i) = lp
} else {
// Remaining 400 points (i = 600..999) get label 1.0; the features follow the same formula.
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
arr(i) = lp
}
}
arr
}

def generateCategoricalDataPoints(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
Expand Down