File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed
mllib/src/main/scala/org/apache/spark/mllib/tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -816,7 +816,15 @@ object DecisionTree extends Serializable with Logging {
816816
817817 val maxBins = strategy.maxBins
818818 val numBins = if (maxBins <= count) maxBins else count.toInt
819- logDebug(" maxBins = " + numBins)
819+ logDebug(" numBins = " + numBins)
820+
821+ // Require that the number of bins is at least the largest category count among the
822+ // categorical features. This is a limitation of the current implementation, but a reasonable
823+ // tradeoff, since features with a large number of categories get favored over continuous features.
824+ if (strategy.categoricalFeaturesInfo.size > 0 ){
825+ val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
826+ require(numBins >= maxCategoriesForFeatures)
827+ }
820828
821829 // Calculate the number of sample for approximate quantile calculation
822830 val requiredSamples = numBins* numBins
You can’t perform that action at this time.
0 commit comments