Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
first version that passes all tests
  • Loading branch information
Yuewei Na authored and Yuewei Na committed Jun 21, 2016
commit 9c99973476c9143535a913761ceced3ab1d73541
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
categoricalFeatures: Map[Int, Int],
numClasses: Int): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
subsamplingRate = 1.0)
subsamplingRate = 1.0, getClassWeights)
}

@Since("1.4.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ class RandomForestClassifier @Since("1.4.0") (
val numClasses: Int = getNumClasses(dataset)
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses,
OldAlgo.Classification, getOldImpurity)
super.getOldStrategy(categoricalFeatures = categoricalFeatures, numClasses = numClasses,
oldAlgo = OldAlgo.Classification, oldImpurity = getOldImpurity,
subsamplingRate = getSubsamplingRate, classWeights = getClassWeights)

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
subsamplingRate = 1.0)
1.0, Array())
}

@Since("1.4.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0,
OldAlgo.Regression, getOldImpurity)
OldAlgo.Regression, getOldImpurity, getSubsamplingRate,
classWeights = Array())

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
Expand Down
94 changes: 71 additions & 23 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
" trees.")

/**
* An array that stores the weights of class labels. All elements must be non-negative.
* (default = Array(1, 1))
* @group Param
*/
final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
" that stores the weights of class labels. All elements must be non-negative.")

setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10,
classWeights -> Array(1.0, 1.0))
maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)

/** @group setParam */
def setMaxDepth(value: Int): this.type = set(maxDepth, value)
Expand Down Expand Up @@ -153,12 +144,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams
/** @group expertGetParam */
final def getCacheNodeIds: Boolean = $(cacheNodeIds)

/** @group SetParam */
def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)

/** @group GetParam */
final def getClassWeights: Array[Double] = $(classWeights)

/**
* Specifies how often to checkpoint the cached node IDs.
* E.g. 10 means that the cache will get checkpointed every 10 iterations.
Expand All @@ -176,7 +161,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams
numClasses: Int,
oldAlgo: OldAlgo.Algo,
oldImpurity: OldImpurity,
subsamplingRate: Double): OldStrategy = {
subsamplingRate: Double,
classWeights: Array[Double]): OldStrategy = {
val strategy = OldStrategy.defaultStrategy(oldAlgo)
strategy.impurity = oldImpurity
strategy.checkpointInterval = getCheckpointInterval
Expand All @@ -189,9 +175,32 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
strategy.subsamplingRate = subsamplingRate
strategy.classWeights = getClassWeights
strategy.classWeights = classWeights
strategy
}

/**
 * (private[ml]) Create a Strategy whose interface is compatible with the old API.
 *
 * This overload takes no `classWeights` argument. It delegates to the overload that does,
 * passing `Array(1.0, 1.0)` as the class weights, so that the two overloads populate the
 * Strategy from the same code path and cannot drift apart.
 *
 * @param categoricalFeatures forwarded unchanged; ends up in `strategy.categoricalFeaturesInfo`
 * @param numClasses forwarded unchanged; ends up in `strategy.numClasses`
 * @param oldAlgo algorithm used to obtain the default old-API Strategy
 * @param oldImpurity forwarded unchanged; ends up in `strategy.impurity`
 * @param subsamplingRate forwarded unchanged; ends up in `strategy.subsamplingRate`
 * @return the old-API Strategy with `classWeights` set to `Array(1.0, 1.0)`
 */
private[ml] def getOldStrategy(
    categoricalFeatures: Map[Int, Int],
    numClasses: Int,
    oldAlgo: OldAlgo.Algo,
    oldImpurity: OldImpurity,
    subsamplingRate: Double): OldStrategy = {
  // Six arguments select the classWeights-aware overload; no recursion into this one.
  getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, subsamplingRate,
    Array(1.0, 1.0))
}
}

/**
Expand All @@ -210,14 +219,28 @@ private[ml] trait TreeClassifierParams extends Params {
s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
(value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))

setDefault(impurity -> "gini")
/**
* An array that stores the weights of class labels. All elements must be non-negative.
* (default = Array(1.0, 1.0))
* @group param
*/
final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
" that stores the weights of class labels. All elements must be non-negative.")

setDefault(impurity -> "gini", classWeights -> Array(1.0, 1.0))

/** @group setParam */
def setImpurity(value: String): this.type = set(impurity, value)

/** @group getParam */
final def getImpurity: String = $(impurity).toLowerCase

/** @group setParam */
def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)

/** @group getParam */
final def getClassWeights: Array[Double] = $(classWeights)

/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
Expand Down Expand Up @@ -257,14 +280,29 @@ private[ml] trait TreeRegressorParams extends Params {
s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
(value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))

setDefault(impurity -> "variance")
/**
* An array that stores the weights of class labels. This parameter will be ignored in
* regression trees.
* (default = Array())
* @group param
*/
final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
" that stores the weights of class labels. All elements must be non-negative.")

setDefault(impurity -> "variance", classWeights -> Array())

/** @group setParam */
def setImpurity(value: String): this.type = set(impurity, value)

/** @group getParam */
final def getImpurity: String = $(impurity).toLowerCase

/** @group setParam */
def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)

/** @group getParam */
final def getClassWeights: Array[Double] = $(classWeights)

/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
Expand Down Expand Up @@ -330,9 +368,19 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
categoricalFeatures: Map[Int, Int],
numClasses: Int,
oldAlgo: OldAlgo.Algo,
oldImpurity: OldImpurity): OldStrategy = {
oldImpurity: OldImpurity,
classWeights: Array[Double]): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo,
oldImpurity, getSubsamplingRate)
oldImpurity, getSubsamplingRate, classWeights)
}

/**
 * (private[ml]) Create a Strategy instance to use with the old API, for callers that do not
 * supply class weights.
 *
 * Forwards to the parent implementation with this instance's subsampling rate
 * (`getSubsamplingRate`) and `Array(1.0, 1.0)` as the class weights.
 */
private[ml] def getOldStrategy(
    categoricalFeatures: Map[Int, Int],
    numClasses: Int,
    oldAlgo: OldAlgo.Algo,
    oldImpurity: OldImpurity): OldStrategy = {
  val defaultClassWeights: Array[Double] = Array(1.0, 1.0)
  super.getOldStrategy(
    categoricalFeatures = categoricalFeatures,
    numClasses = numClasses,
    oldAlgo = oldAlgo,
    oldImpurity = oldImpurity,
    subsamplingRate = getSubsamplingRate,
    classWeights = defaultClassWeights)
}
}

Expand Down Expand Up @@ -475,7 +523,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
categoricalFeatures: Map[Int, Int],
oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2,
oldAlgo, OldVariance)
oldAlgo, OldVariance, Array(1.0, 1.0))

// NOTE: The old API does not support "seed" so we ignore it.
new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class Strategy @Since("1.3.0") (
}

/**
* Make the class compatible with previous versions
* Make the Strategy class compatible with the old API
*/
@Since("2.0.0")
def this(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class DecisionTreeClassifierSuite
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////

test("Binary classification with setting explicit uniform class weights") {
test("Binary classification with explicitly setting uniform class weights") {
val dt = new DecisionTreeClassifier()
.setImpurity("WeightedGini")
.setMaxDepth(2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ private object RandomForestClassifierSuite extends SparkFunSuite {
val numFeatures = data.first().features.size
val oldStrategy =
rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification,
rf.getOldImpurity)
rf.getOldImpurity, rf.getSubsamplingRate, rf.getClassWeights)
val oldModel = OldRandomForest.trainClassifier(
data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy,
rf.getSeed.toInt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
val numFeatures = data.first().features.size
val oldStrategy =
rf.getOldStrategy(categoricalFeatures, numClasses = 0,
OldAlgo.Regression, rf.getOldImpurity)
OldAlgo.Regression, rf.getOldImpurity, rf.getSubsamplingRate,
classWeights = Array())
val oldModel = OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy,
rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
Expand Down