Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
minor modifications for code styling
  • Loading branch information
Yuewei Na authored and Yuewei Na committed Jun 21, 2016
commit f53a2ccf001ec7db54ebff010fa321e3c116d9a5
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,8 @@ class RandomForestClassifier @Since("1.4.0") (
val numClasses: Int = getNumClasses(dataset)
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures = categoricalFeatures, numClasses = numClasses,
oldAlgo = OldAlgo.Classification, oldImpurity = getOldImpurity,
subsamplingRate = getSubsamplingRate, classWeights = getClassWeights)
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification,
getOldImpurity, getSubsamplingRate, getClassWeights)

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
1.0, Array())
subsamplingRate = 1.0, classWeights = Array())
}

@Since("1.4.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0,
OldAlgo.Regression, getOldImpurity, getSubsamplingRate,
classWeights = Array())
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression,
getOldImpurity, getSubsamplingRate, classWeights = Array())

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ private[ml] object DecisionTreeModelReadWrite {
Param.jsonDecode[String](compact(render(impurityJson)))
}

// Get the class weights needed to construct the ImpurityCalculator. This value
// is ignored unless the impurity is WeightedGini.
val classWeights: Array[Double] = {
val classWeightsJson: JValue = metadata.getParamValue("classWeights")
compact(render(classWeightsJson)).split("\\[|,|\\]")
Expand Down Expand Up @@ -445,6 +447,8 @@ private[ml] object EnsembleModelReadWrite {
Param.jsonDecode[String](compact(render(impurityJson)))
}

// Get the class weights needed to construct the ImpurityCalculator. This value
// is ignored unless the impurity is WeightedGini.
val classWeights: Array[Double] = {
val classWeightsJson: JValue = metadata.getParamValue("classWeights")
val classWeightsArray = compact(render(classWeightsJson)).split("\\[|,|\\]")
Expand Down
40 changes: 20 additions & 20 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
*/
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

/** (private[ml]) Create a Strategy instance to use with the old API. */
/** (private[ml]) Create a Strategy instance. */
private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
numClasses: Int,
Expand All @@ -181,25 +181,25 @@ private[ml] trait DecisionTreeParams extends PredictorParams

/** (private[ml]) Create a Strategy whose interface is compatible with the old API. */
private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
numClasses: Int,
oldAlgo: OldAlgo.Algo,
oldImpurity: OldImpurity,
subsamplingRate: Double): OldStrategy = {
val strategy = OldStrategy.defaultStrategy(oldAlgo)
strategy.impurity = oldImpurity
strategy.checkpointInterval = getCheckpointInterval
strategy.maxBins = getMaxBins
strategy.maxDepth = getMaxDepth
strategy.maxMemoryInMB = getMaxMemoryInMB
strategy.minInfoGain = getMinInfoGain
strategy.minInstancesPerNode = getMinInstancesPerNode
strategy.useNodeIdCache = getCacheNodeIds
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
strategy.subsamplingRate = subsamplingRate
strategy.classWeights = Array(1.0, 1.0)
strategy
categoricalFeatures: Map[Int, Int],
numClasses: Int,
oldAlgo: OldAlgo.Algo,
oldImpurity: OldImpurity,
subsamplingRate: Double): OldStrategy = {
val strategy = OldStrategy.defaultStrategy(oldAlgo)
strategy.impurity = oldImpurity
strategy.checkpointInterval = getCheckpointInterval
strategy.maxBins = getMaxBins
strategy.maxDepth = getMaxDepth
strategy.maxMemoryInMB = getMaxMemoryInMB
strategy.minInfoGain = getMinInfoGain
strategy.minInstancesPerNode = getMinInstancesPerNode
strategy.useNodeIdCache = getCacheNodeIds
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
strategy.subsamplingRate = subsamplingRate
strategy.classWeights = Array(1.0, 1.0)
strategy
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}

/**
* :: Experimental ::
* Class for calculating the
* [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
* during binary classification.
* Class for calculating the Gini impurity with class weights, using the
* altered priors method, during classification.
*/
@Since("1.0.0")
@Since("2.0.0")
@Experimental
object WeightedGini extends Impurity {

Expand All @@ -36,7 +35,7 @@ object WeightedGini extends Impurity {
* @param weightedTotalCount sum of counts for all labels
* @return information value, or 0 if weightedTotalCount = 0
*/
@Since("1.1.0")
@Since("2.0.0")
@DeveloperApi
override def calculate(weightedCounts: Array[Double], weightedTotalCount: Double): Double = {
if (weightedTotalCount == 0) {
Expand All @@ -61,7 +60,7 @@ object WeightedGini extends Impurity {
* @param sumSquares summation of squares of the labels
* @return never returns normally: WeightedGini does not support this overload and
* always throws UnsupportedOperationException
*/
@Since("1.0.0")
@Since("2.0.0")
@DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("WeightedGini.calculate")
Expand All @@ -70,7 +69,7 @@ object WeightedGini extends Impurity {
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
@Since("1.1.0")
@Since("2.0.0")
def instance: this.type = this

}
Expand Down Expand Up @@ -179,10 +178,10 @@ private[spark] class WeightedGiniCalculator(stats: Array[Double], classWeights:
s"Two ImpurityCalculator instances cannot be added with different counts sizes." +
s" Sizes are ${stats.length} and ${other.stats.length}.")
val otherCalculator = other.asInstanceOf[WeightedGiniCalculator]
val len = otherCalculator.stats.length
var i = 0
val len = other.stats.length
while (i < len) {
stats(i) += other.stats(i)
stats(i) += otherCalculator.stats(i)
weightedStats(i) += otherCalculator.weightedStats(i)
i += 1
}
Expand All @@ -198,10 +197,10 @@ private[spark] class WeightedGiniCalculator(stats: Array[Double], classWeights:
s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." +
s" Sizes are ${stats.length} and ${other.stats.length}.")
val otherCalculator = other.asInstanceOf[WeightedGiniCalculator]
val len = otherCalculator.stats.length
var i = 0
val len = other.stats.length
while (i < len) {
stats(i) -= other.stats(i)
stats(i) -= otherCalculator.stats(i)
weightedStats(i) -= otherCalculator.weightedStats(i)
i += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}




class DecisionTreeClassifierSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.ml.classification

import scala.io.Source

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
Expand All @@ -34,7 +32,6 @@ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}


/**
* Test suite for [[RandomForestClassifier]].
*/
Expand Down Expand Up @@ -213,7 +210,7 @@ class RandomForestClassifierSuite
assert(model.numClasses === model2.numClasses)
}

val rf = new RandomForestClassifier().setNumTrees(2).setClassWeights(Array())
val rf = new RandomForestClassifier().setNumTrees(2)
val rdd = TreeTests.getTreeReadWriteData(sc)

val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")
Expand Down