Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
change getOldImpu interfaces
  • Loading branch information
Yuewei Na authored and Yuewei Na committed Jun 17, 2016
commit fd2eee567deb3308e6184f4ecb20f681f9fa9353
13 changes: 6 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
numClasses: Int,
oldAlgo: OldAlgo.Algo,
oldImpurity: OldImpurity,
subsamplingRate: Double,
classWeights: Array[Double]): OldStrategy = {
subsamplingRate: Double): OldStrategy = {
val strategy = OldStrategy.defaultStrategy(oldAlgo)
strategy.impurity = oldImpurity
strategy.checkpointInterval = getCheckpointInterval
Expand All @@ -190,7 +189,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
strategy.subsamplingRate = subsamplingRate
strategy.classWeights = classWeights
strategy.classWeights = getClassWeights
strategy
}
}
Expand Down Expand Up @@ -331,10 +330,9 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
categoricalFeatures: Map[Int, Int],
numClasses: Int,
oldAlgo: OldAlgo.Algo,
oldImpurity: OldImpurity,
classWeights: Array[Double] = Array()): OldStrategy = {
oldImpurity: OldImpurity): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo,
oldImpurity, getSubsamplingRate, classWeights)
oldImpurity, getSubsamplingRate)
}
}

Expand Down Expand Up @@ -477,7 +475,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
categoricalFeatures: Map[Int, Int],
oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2,
oldAlgo, OldVariance, classWeights = Array(1.0, 1.0))
oldAlgo, OldVariance)

// NOTE: The old API does not support "seed" so we ignore it.
new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class Strategy @Since("1.3.0") (
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
@Since("2.0.0") var classWeights: Array[Int] = Array(1, 1))
@Since("2.0.0") var classWeights: Array[Double] = Array(1, 1))
extends Serializable {

/**
Expand Down