Merged
Changes from 1 commit
Commits
76 commits
711356b
[SPARK-3086] [SPARK-3043] [SPARK-3156] [mllib] DecisionTree aggregat…
jkbradley Sep 8, 2014
e16a8e7
SPARK-3337 Paranoid quoting in shell to allow install dirs with space…
ScrapCodes Sep 8, 2014
16a73c2
SPARK-2978. Transformation with MR shuffle semantics
sryza Sep 8, 2014
386bc24
Provide a default PYSPARK_PYTHON for python/run_tests
Sep 8, 2014
26bc765
[SQL] Minor edits to sql programming guide.
hcook Sep 8, 2014
939a322
[SPARK-3417] Use new-style classes in PySpark
mrocklin Sep 8, 2014
08ce188
[SPARK-3019] Pluggable block transfer interface (BlockTransferService)
rxin Sep 8, 2014
7db5339
[SPARK-3349][SQL] Output partitioning of limit should not be inherite…
Sep 8, 2014
50a4fa7
[SPARK-3443][MLLIB] update default values of tree:
mengxr Sep 9, 2014
ca0348e
SPARK-3423: [SQL] Implement BETWEEN for SQLParser
willb Sep 9, 2014
dc1dbf2
[SPARK-3414][SQL] Stores analyzed logical plan when registering a tem…
liancheng Sep 9, 2014
2b7ab81
[SPARK-3329][SQL] Don't depend on Hive SET pair ordering in tests.
willb Sep 9, 2014
092e2f1
SPARK-2425 Don't kill a still-running Application because of some mis…
markhamstra Sep 9, 2014
ce5cb32
[Build] Removed -Phive-thriftserver since this profile has been removed
liancheng Sep 9, 2014
c419e4f
[Docs] actorStream storageLevel default is MEMORY_AND_DISK_SER_2
melrief Sep 9, 2014
1e03cf7
[SPARK-3455] [SQL] **HOT FIX** Fix the unit test failure
chenghao-intel Sep 9, 2014
88547a0
SPARK-3422. JavaAPISuite.getHadoopInputSplits isn't used anywhere.
sryza Sep 9, 2014
f0f1ba0
SPARK-3404 [BUILD] SparkSubmitSuite fails with "spark-submit exits wi…
srowen Sep 9, 2014
2686233
[SPARK-3193]output errer info when Process exit code is not zero in t…
scwf Sep 9, 2014
02b5ac7
Minor - Fix trivial compilation warnings.
ScrapCodes Sep 9, 2014
07ee4a2
[SPARK-3176] Implement 'ABS and 'LAST' for sql
Sep 9, 2014
c110614
[SPARK-3448][SQL] Check for null in SpecificMutableRow.update
liancheng Sep 10, 2014
25b5b86
[SPARK-3458] enable python "with" statements for SparkContext
Sep 10, 2014
b734ed0
[SPARK-3395] [SQL] DSL sometimes incorrectly reuses attribute ids, br…
Sep 10, 2014
6f7a768
[SPARK-3286] - Cannot view ApplicationMaster UI when Yarn’s url schem…
Sep 10, 2014
a028330
[SPARK-3362][SQL] Fix resolution for casewhen with nulls.
adrian-wang Sep 10, 2014
f0c87dc
[SPARK-3363][SQL] Type Coercion should promote null to all other types.
adrian-wang Sep 10, 2014
26503fd
[HOTFIX] Fix scala style issue introduced by #2276.
JoshRosen Sep 10, 2014
1f4a648
SPARK-1713. Use a thread pool for launching executors.
sryza Sep 10, 2014
e4f4886
[SPARK-2096][SQL] Correctly parse dot notations
cloud-fan Sep 10, 2014
558962a
[SPARK-3411] Improve load-balancing of concurrently-submitted drivers…
WangTaoTheTonic Sep 10, 2014
79cdb9b
[SPARK-2207][SPARK-3272][MLLib]Add minimum information gain and minim…
Sep 10, 2014
84e2c8b
[SQL] Add test case with workaround for reading partitioned Avro files
marmbrus Sep 11, 2014
f92cde2
[SPARK-3447][SQL] Remove explicit conversion with JListWrapper to avo…
marmbrus Sep 11, 2014
c27718f
[SPARK-2781][SQL] Check resolution of LogicalPlans in Analyzer.
staple Sep 11, 2014
ed1980f
[SPARK-2140] Updating heap memory calculation for YARN stable and alpha.
Sep 11, 2014
1ef656e
[SPARK-3047] [PySpark] add an option to use str in textFileRDD
davies Sep 11, 2014
ca83f1e
[SPARK-2917] [SQL] Avoid table creation in logical plan analyzing for…
chenghao-intel Sep 11, 2014
4bc9e04
[SPARK-3390][SQL] sqlContext.jsonRDD fails on a complex structure of …
yhuai Sep 11, 2014
6324eb7
[Spark-3490] Disable SparkUI for tests
andrewor14 Sep 12, 2014
ce59725
[SPARK-3429] Don't include the empty string "" as a defaultAclUser
ash211 Sep 12, 2014
f858f46
SPARK-3462 push down filters and projections into Unions
Sep 12, 2014
33c7a73
SPARK-2482: Resolve sbt warnings during build
witgo Sep 12, 2014
42904b8
[SPARK-3465] fix task metrics aggregation in local mode
davies Sep 12, 2014
b8634df
[SPARK-3160] [SPARK-3494] [mllib] DecisionTree: eliminate pre-alloca…
jkbradley Sep 12, 2014
f116f76
[SPARK-2558][DOCS] Add --queue example to YARN doc
kramimus Sep 12, 2014
5333776
[PySpark] Add blank line so that Python RDD.top() docstring renders c…
rnowling Sep 12, 2014
8194fc6
[SPARK-3481] [SQL] Eliminate the error log in local Hive comparison test
chenghao-intel Sep 12, 2014
eae81b0
MAINTENANCE: Automated closing of pull requests.
pwendell Sep 12, 2014
15a5645
[SPARK-3427] [GraphX] Avoid active vertex tracking in static PageRank
ankurdave Sep 12, 2014
1d76796
SPARK-3014. Log a more informative messages in a couple failure scena…
sryza Sep 12, 2014
af25838
[SPARK-3217] Add Guava to classpath when SPARK_PREPEND_CLASSES is set.
Sep 12, 2014
25311c2
[SPARK-3456] YarnAllocator on alpha can lose container requests to RM
tgravescs Sep 13, 2014
71af030
[SPARK-3094] [PySpark] compatitable with PyPy
davies Sep 13, 2014
885d162
[SPARK-3500] [SQL] use JavaSchemaRDD as SchemaRDD._jschema_rdd
davies Sep 13, 2014
6d887db
[SPARK-3515][SQL] Moves test suite setup code to beforeAll rather tha…
liancheng Sep 13, 2014
2584ea5
[SPARK-3469] Make sure all TaskCompletionListener are called even wit…
rxin Sep 13, 2014
e11eeb7
[SQL][Docs] Update SQL programming guide to show the correct default …
yhuai Sep 13, 2014
feaa370
SPARK-3470 [CORE] [STREAMING] Add Closeable / close() to Java context…
srowen Sep 13, 2014
b4dded4
Proper indent for the previous commit.
rxin Sep 13, 2014
a523cea
[SQL] [Docs] typo fixes
nchammas Sep 13, 2014
184cd51
[SPARK-3481][SQL] Removes the evil MINOR HACK
liancheng Sep 13, 2014
7404924
[SPARK-3294][SQL] Eliminates boxing costs from in-memory columnar sto…
liancheng Sep 13, 2014
0f8c4ed
[SQL] Decrease partitions when testing
marmbrus Sep 13, 2014
2aea0da
[SPARK-3030] [PySpark] Reuse Python worker
davies Sep 13, 2014
4e3fbe8
[SPARK-3463] [PySpark] aggregate and show spilled bytes in Python
davies Sep 14, 2014
c243b21
SPARK-3039: Allow spark to be built using avro-mapred for hadoop2
bbossy Sep 15, 2014
f493f79
[SPARK-3452] Maven build should skip publishing artifacts people shou…
ScrapCodes Sep 15, 2014
cc14644
[SPARK-3410] The priority of shutdownhook for ApplicationMaster shoul…
sarutak Sep 15, 2014
fe2b1d6
[SPARK-3425] do not set MaxPermSize for OpenJDK 1.8
Sep 15, 2014
e59fac1
[SPARK-3518] Remove wasted statement in JsonProtocol
sarutak Sep 15, 2014
37d9252
[SPARK-2714] DAGScheduler logs jobid when runJob finishes
YanTangZhai Sep 15, 2014
3b93128
[SPARK-3396][MLLIB] Use SquaredL2Updater in LogisticRegressionWithSGD
BigCrunsh Sep 16, 2014
983d6a9
[MLlib] Update SVD documentation in IndexedRowMatrix
rezazadeh Sep 16, 2014
fdb302f
[SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGa…
Sep 16, 2014
da33acb
[SPARK-2951] [PySpark] support unpickle array.array for Python 2.6
davies Sep 16, 2014
[SPARK-3086] [SPARK-3043] [SPARK-3156] [mllib] DecisionTree aggregation improvements

Summary:
1. Variable numBins for each feature [SPARK-3043]
2. Reduced data reshaping in aggregation [SPARK-3043]
3. Choose ordering for ordered categorical features adaptively [SPARK-3156]
4. Changed nodes to use 1-indexing [SPARK-3086]
5. Small clean-ups

Note: This PR looks bigger than it is since I moved several functions from inside findBestSplitsPerGroup to outside of it (to make it clear what was being serialized in the aggregation).

Speedups: This update helps most when many features use few bins but a few features use many bins.  Some example results on speedups with 2M examples, 3.5K features (15-worker EC2 cluster):
* Example where old code was reasonably efficient (1/2 continuous, 1/4 binary, 1/4 20-category): 164.813 --> 116.491 sec
* Example where old code wasted many bins (1/10 continuous, 81/100 binary, 9/100 20-category): 128.701 --> 39.334 sec

Details:

(1) Variable numBins for each feature [SPARK-3043]

DecisionTreeMetadata now computes a variable numBins for each feature.  It also tracks numSplits.
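
Roughly, the bookkeeping looks like this (it mirrors DecisionTreeMetadata.numSplits in the diff below; the free-standing signature here is illustrative):

  // Ordered features have 1 more bin than splits; unordered features have 2 bins per split.
  def numSplits(featureIndex: Int): Int =
    if (isUnordered(featureIndex)) numBins(featureIndex) >> 1
    else numBins(featureIndex) - 1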

(2) Reduced data reshaping in aggregation [SPARK-3043]

Added DTStatsAggregator, a wrapper around the aggregate statistics array for easy but efficient indexing.
* Added ImpurityAggregator and ImpurityCalculator classes, to make DecisionTree code more oblivious to the type of impurity.
* Design note: I originally tried creating Impurity classes which stored data and storing the aggregates in an Array[Array[Array[Impurity]]].  However, this led to significant slowdowns, perhaps because of overhead in creating so many objects.

The aggregate statistics are never reshaped, and cumulative sums are computed in-place.
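
Concretely, the start index for the stats of a (node, feature, bin) in the flat array is (see DTStatsAggregator in the diff below):

  index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize

and the in-place cumulative sums reuse the bin slots themselves. Assuming the aggregator API from this diff, the prefix-sum pass for one (node, feature) looks roughly like:

  // After this loop, bin b holds the aggregate of bins 0..b, so the left-child
  // stats for a split after bin b can be read directly from bin b.
  val offset = agg.getNodeFeatureOffset(nodeIndex, featureIndex)
  var b = 1
  while (b < agg.numBins(featureIndex)) {
    agg.mergeForNodeFeature(offset, b, b - 1)  // merge bin b - 1 into bin b
    b += 1
  }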

Updated the layout of aggregation functions.  The update simplifies things by (1) dividing features into ordered/unordered (instead of ordered/unordered/continuous) and (2) making use of the DTStatsAggregator for indexing.
For this update, the following functions were refactored:
* updateBinForOrderedFeature
* updateBinForUnorderedFeature
* binaryOrNotCategoricalBinSeqOp
* multiclassWithCategoricalBinSeqOp
* regressionBinSeqOp
The above 5 functions were replaced with:
* orderedBinSeqOp
* someUnorderedBinSeqOp
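
Roughly, the ordered-feature seqOp has this shape (a sketch, assuming TreePoint carries a label and an Array[Int] of binned feature values; unordered features take the left/right-offset path instead):

  def orderedBinSeqOp(agg: DTStatsAggregator, treePoint: TreePoint, nodeIndex: Int): Unit = {
    val label = treePoint.label
    val nodeOffset = agg.getNodeOffset(nodeIndex)  // pre-compute the node offset once
    var featureIndex = 0
    while (featureIndex < agg.numFeatures) {
      agg.nodeUpdate(nodeOffset, featureIndex, treePoint.binnedFeatures(featureIndex), label)
      featureIndex += 1
    }
  }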

Other changes:
* calculateGainForSplit now treats all feature types the same way.
* Eliminated extractLeftRightNodeAggregates.

(3) Choose ordering for ordered categorical features adaptively [SPARK-3156]

Updated binsToBestSplit():
* This now computes cumulative sums of stats for ordered features.
* For ordered categorical features, it chooses an ordering for categories. (This used to be done by findSplitsBins; see the sketch below.)
* Uses iterators to shorten code and avoid building an Array[Array[InformationGainStats]].

Side effects:
* In findSplitsBins: A sample of the data is only taken for data with continuous features.  It is not needed for data with only categorical features.
* In findSplitsBins: splits and bins are no longer pre-computed for ordered categorical features since they are not needed.
* TreePoint binning is simpler for categorical features.
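
A sketch of the adaptive ordering in binsToBestSplit (assuming the centroid is the average label / predicted value read from each category's ImpurityCalculator, with empty categories sorted last):

  val centroidForCategories = Range(0, numBins).map { featureValue =>
    val categoryStats = agg.getImpurityCalculator(nodeFeatureOffset, featureValue)
    val centroid = if (categoryStats.count != 0) categoryStats.predict else Double.MaxValue
    (featureValue, centroid)
  }
  // Categories are then scanned in centroid order with in-place cumulative sums,
  // so only (numCategories - 1) candidate splits are evaluated.
  val categoriesSortedByCentroid = centroidForCategories.sortBy(_._2)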

(4) Changed nodes to use 1-indexing [SPARK-3086]

Nodes used to be indexed from 0.  Now they are indexed from 1.
Node indexing functions are now collected in object Node (Node.scala).
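
With 1-indexing, the child/parent arithmetic is the standard binary-heap form (a sketch of the kind of helpers collected in object Node; exact names may differ):

  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1          // 2 * i
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1   // 2 * i + 1
  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1             // i / 2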

(5) Small clean-ups

Eliminated functions extractNodeInfo() and extractInfoForLowerLevels() to reduce duplicate code.
Eliminated InvalidBinIndex since it is no longer used.

CC: mengxr  manishamde  Please let me know if you have thoughts on this—thanks!

Author: Joseph K. Bradley <[email protected]>

Closes apache#2125 from jkbradley/dt-opt3alt and squashes the following commits:

42c192a [Joseph K. Bradley] Merge branch 'rfs' into dt-opt3alt
d3cc46b [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
00e4404 [Joseph K. Bradley] optimization for TreePoint construction (pre-computing featureArity and isUnordered as arrays)
425716c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs
a2acea5 [Joseph K. Bradley] Small optimizations based on profiling
aa4e4df [Joseph K. Bradley] Updated DTStatsAggregator with bug fix (nodeString should not be multiplied by statsSize)
4651154 [Joseph K. Bradley] Changed numBins semantics for unordered features. * Before: numBins = numSplits = (1 << k - 1) - 1 * Now: numBins = 2 * numSplits = 2 * [(1 << k - 1) - 1] * This also involved changing the semantics of: ** DecisionTreeMetadata.numUnorderedBins()
1e3b1c7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
1485fcc [Joseph K. Bradley] Made some DecisionTree methods private.
92f934f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
e676da1 [Joseph K. Bradley] Updated documentation for DecisionTree
37ca845 [Joseph K. Bradley] Fixed problem with how DecisionTree handles ordered categorical	features.
105f8ab [Joseph K. Bradley] Removed commented-out getEmptyBinAggregates from DecisionTree
062c31d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
6d32ccd [Joseph K. Bradley] In DecisionTree.binsToBestSplit, changed loops to iterators to shorten code.
807cd00 [Joseph K. Bradley] Finished DTStatsAggregator, a wrapper around the aggregate statistics for easy but hopefully efficient indexing.  Modified old ImpurityAggregator classes and renamed them ImpurityCalculator; added ImpurityAggregator classes which work with DTStatsAggregator but do not store data.  Unit tests all succeed.
f2166fd [Joseph K. Bradley] still working on DTStatsAggregator
92f7118 [Joseph K. Bradley] Added partly written DTStatsAggregator
fd8df30 [Joseph K. Bradley] Moved some aggregation helpers outside of findBestSplitsPerGroup
d7c53ee [Joseph K. Bradley] Added more doc for ImpurityAggregator
a40f8f1 [Joseph K. Bradley] Changed nodes to be indexed from 1.  Tests work.
95cad7c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3
5f94342 [Joseph K. Bradley] Added treeAggregate since not yet merged from master.  Moved node indexing functions to Node.
61c4509 [Joseph K. Bradley] Fixed bugs from merge: missing DT timer call, and numBins setting.  Cleaned up DT Suite some.
3ba7166 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3
b314659 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3
9c83363 [Joseph K. Bradley] partial merge but not done yet
45f7ea7 [Joseph K. Bradley] partial merge, not yet done
5fce635 [Joseph K. Bradley] Merge branch 'dt-opt2' into dt-opt3
26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used.  Removed debugging println calls in DecisionTree.scala.
356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
430d782 [Joseph K. Bradley] Added more debug info on binning error.  Added some docs.
d036089 [Joseph K. Bradley] Print timing info to logDebug.
e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private
8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up.  Removed debugging println calls from DecisionTree.  Made TreePoint extend Serialiable
a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
dd4d3aa [Joseph K. Bradley] Mid-process in bug fix: bug for binary classification with categorical features * Bug: Categorical features were all treated as ordered for binary classification.  This is possible but would require the bin ordering to be determined on-the-fly after the aggregation.  Currently, the ordering is determined a priori and fixed for all splits. * (Temp) Fix: Treat low-arity categorical features as unordered for binary classification. * Related change: I removed most tests for isMulticlass in the code.  I instead test metadata for whether there are unordered features. * Status: The bug may be fixed, but more testing needs to be done.
438a660 [Joseph K. Bradley] removed subsampling for mnist8m from DT
86e217f [Joseph K. Bradley] added cache to DT input
e3c84cc [Joseph K. Bradley] Added stuff fro mnist8m to D T Runner
51ef781 [Joseph K. Bradley] Fixed bug introduced by last commit: Variance impurity calculation was incorrect since counts were swapped accidentally
fd65372 [Joseph K. Bradley] Major changes: * Created ImpurityAggregator classes, rather than old aggregates. * Feature split/bin semantics are based on ordered vs. unordered ** E.g.: numSplits = numBins for all unordered features, and numSplits = numBins - 1 for all ordered features. * numBins can differ for each feature
c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification
b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes
b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt
0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree
3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging)
f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
a95bc22 [Joseph K. Bradley] timing for DecisionTree internals
jkbradley authored and mengxr committed Sep 8, 2014
commit 711356b422c66e2a80377a9f43fce97282460520
1,341 changes: 450 additions & 891 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

@@ -0,0 +1,213 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impl

import org.apache.spark.mllib.tree.impurity._

/**
 * DecisionTree statistics aggregator.
 * This holds a flat array of statistics for a set of (nodes, features, bins)
 * and helps with indexing.
 */
private[tree] class DTStatsAggregator(
    val metadata: DecisionTreeMetadata,
    val numNodes: Int) extends Serializable {

  /**
   * [[ImpurityAggregator]] instance specifying the impurity type.
   */
  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
    case Gini => new GiniAggregator(metadata.numClasses)
    case Entropy => new EntropyAggregator(metadata.numClasses)
    case Variance => new VarianceAggregator()
    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
  }

  /**
   * Number of elements (Double values) used for the sufficient statistics of each bin.
   */
  val statsSize: Int = impurityAggregator.statsSize

  val numFeatures: Int = metadata.numFeatures

  /**
   * Number of bins for each feature. This is indexed by the feature index.
   */
  val numBins: Array[Int] = metadata.numBins

  /**
   * Number of splits for the given feature.
   */
  def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)

  /**
   * Indicator for each feature of whether that feature is an unordered feature.
   * TODO: Is Array[Boolean] any faster?
   */
  def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)

  /**
   * Offset for each feature for calculating indices into the [[allStats]] array.
   */
  private val featureOffsets: Array[Int] = {
    def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
      if (isUnordered(featureIndex)) {
        total + 2 * numBins(featureIndex)
      } else {
        total + numBins(featureIndex)
      }
    }
    Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
  }

  /**
   * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
   */
  private val nodeStride: Int = featureOffsets.last

  /**
   * Total number of elements stored in this aggregator.
   */
  val allStatsSize: Int = numNodes * nodeStride

  /**
   * Flat array of elements.
   * Index for start of stats for a (node, feature, bin) is:
   *   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
   * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex))
   *       and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex))
   */
  val allStats: Array[Double] = new Array[Double](allStatsSize)

  /**
   * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
   *                           from [[getNodeFeatureOffset]].
   *                           For unordered features, this is a pre-computed
   *                           (node, feature, left/right child) offset from
   *                           [[getLeftRightNodeFeatureOffsets]].
   */
  def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
    impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
  }

  /**
   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
   */
  def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
    val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
    impurityAggregator.update(allStats, i, label)
  }

  /**
   * Pre-compute node offset for use with [[nodeUpdate]].
   */
  def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride

  /**
   * Faster version of [[update]].
   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
   * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
   */
  def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
    impurityAggregator.update(allStats, i, label)
  }

  /**
   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
   * For ordered features only.
   */
  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
    require(!isUnordered(featureIndex),
      s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" +
      s" for unordered feature $featureIndex.")
    nodeIndex * nodeStride + featureOffsets(featureIndex)
  }

  /**
   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
   * For unordered features only.
   */
  def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
    require(isUnordered(featureIndex),
      s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
      s" but was called for ordered feature $featureIndex.")
    val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
    (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
  }

  /**
   * Faster version of [[update]].
   * Update the stats for a given (node, feature, bin), using the given label.
   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
   *                           from [[getNodeFeatureOffset]].
   *                           For unordered features, this is a pre-computed
   *                           (node, feature, left/right child) offset from
   *                           [[getLeftRightNodeFeatureOffsets]].
   */
  def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = {
    impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label)
  }

  /**
   * For a given (node, feature), merge the stats for two bins.
   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
   *                           from [[getNodeFeatureOffset]].
   *                           For unordered features, this is a pre-computed
   *                           (node, feature, left/right child) offset from
   *                           [[getLeftRightNodeFeatureOffsets]].
   * @param binIndex  The other bin is merged into this bin.
   * @param otherBinIndex  This bin is not modified.
   */
  def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
    impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
      nodeFeatureOffset + otherBinIndex * statsSize)
  }

  /**
   * Merge this aggregator with another, and returns this aggregator.
   * This method modifies this aggregator in-place.
   */
  def merge(other: DTStatsAggregator): DTStatsAggregator = {
    require(allStatsSize == other.allStatsSize,
      s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
        + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
    var i = 0
    // TODO: Test BLAS.axpy
    while (i < allStatsSize) {
      allStats(i) += other.allStats(i)
      i += 1
    }
    this
  }

}

private[tree] object DTStatsAggregator extends Serializable {

  /**
   * Combines two aggregates (modifying the first) and returns the combination.
   */
  def binCombOp(
      agg1: DTStatsAggregator,
      agg2: DTStatsAggregator): DTStatsAggregator = {
    agg1.merge(agg2)
  }

}
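
A schematic of how binCombOp plugs into the RDD aggregation (the actual call site is in DecisionTree.scala, whose diff is not rendered above; the seqOp body is elided here):

  val binAggregates = input.aggregate(new DTStatsAggregator(metadata, numNodes))(
    (agg, treePoint) => { /* orderedBinSeqOp / unordered updates */ agg },
    DTStatsAggregator.binCombOp)
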
@@ -26,14 +26,15 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.rdd.RDD


/**
 * Learning and dataset metadata for DecisionTree.
 *
 * @param numClasses    For classification: labels can take values {0, ..., numClasses - 1}.
 *                      For regression: fixed at 0 (no meaning).
 * @param maxBins  Maximum number of bins, for all features.
 * @param featureArity  Map: categorical feature index --> arity.
 *                      I.e., the feature takes values in {0, ..., arity - 1}.
 * @param numBins  Number of bins for each feature.
 */
private[tree] class DecisionTreeMetadata(
    val numFeatures: Int,
@@ -42,6 +43,7 @@ private[tree] class DecisionTreeMetadata(
    val maxBins: Int,
    val featureArity: Map[Int, Int],
    val unorderedFeatures: Set[Int],
    val numBins: Array[Int],
    val impurity: Impurity,
    val quantileStrategy: QuantileStrategy) extends Serializable {

@@ -57,10 +59,26 @@

  def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)

  /**
   * Number of splits for the given feature.
   * For unordered features, there are 2 bins per split.
   * For ordered features, there is 1 more bin than split.
   */
  def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
    numBins(featureIndex) >> 1
  } else {
    numBins(featureIndex) - 1
  }

}

private[tree] object DecisionTreeMetadata {

  /**
   * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
   * This computes which categorical features will be ordered vs. unordered,
   * as well as the number of splits and bins for each feature.
   */
  def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {

    val numFeatures = input.take(1)(0).features.size
@@ -70,32 +88,55 @@ private[tree] object DecisionTreeMetadata {
      case Regression => 0
    }

    val maxBins = math.min(strategy.maxBins, numExamples).toInt
    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
    val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt

    // We check the number of bins here against maxPossibleBins.
    // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
    // based on the number of training examples.
    if (strategy.categoricalFeaturesInfo.nonEmpty) {
      val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
      require(maxCategoriesPerFeature <= maxPossibleBins,
        s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
        s"in categorical features (= $maxCategoriesPerFeature)")
    }

    val unorderedFeatures = new mutable.HashSet[Int]()
    val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
    if (numClasses > 2) {
      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
        if (k - 1 < log2MaxBinsp1) {
          // Note: The above check is equivalent to checking:
          //       numUnorderedBins = (1 << k - 1) - 1 < maxBins
          unorderedFeatures.add(f)
      // Multiclass classification
      val maxCategoriesForUnorderedFeature =
        ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
        // Decide if some categorical features should be treated as unordered features,
        // which require 2 * ((1 << numCategories - 1) - 1) bins.
        // We do this check with log values to prevent overflows in case numCategories is large.
        // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
        if (numCategories <= maxCategoriesForUnorderedFeature) {
          unorderedFeatures.add(featureIndex)
          numBins(featureIndex) = numUnorderedBins(numCategories)
        } else {
          // TODO: Allow this case, where we simply will know nothing about some categories?
          require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
            s"in categorical features (>= $k)")
          numBins(featureIndex) = numCategories
        }
      }
    } else {
      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
        require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
          s"in categorical features (>= $k)")
      // Binary classification or regression
      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
        numBins(featureIndex) = numCategories
      }
    }

    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
      strategy.impurity, strategy.quantileCalculationStrategy)
  }

  /**
   * Given the arity of a categorical feature (arity = number of categories),
   * return the number of bins for the feature if it is to be treated as an unordered feature.
   * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
   * there are math.pow(2, arity - 1) - 1 such splits.
   * Each split has 2 corresponding bins.
   */
  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
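  // Example: arity = 4 gives (1 << 3) - 1 = 7 ways to partition the 4 categories
  // into 2 non-empty sets, hence numUnorderedBins(4) = 2 * 7 = 14 bins.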

}