Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
first version that pass all run_test, by adding redundant constructor…
… and reverting getOldStrategy to old versions
  • Loading branch information
Yuewei Na authored and Yuewei Na committed Jun 20, 2016
commit 455c47e274e1dff50268a6c07b2f0a67c32ac24c
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
categoricalFeatures: Map[Int, Int],
numClasses: Int): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
subsamplingRate = 1.0, getClassWeights)
subsamplingRate = 1.0)
}

@Since("1.4.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class RandomForestClassifier @Since("1.4.0") (
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses,
OldAlgo.Classification, getOldImpurity, getClassWeights)
OldAlgo.Classification, getOldImpurity)

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
subsamplingRate = 1.0, classWeights = Array())
subsamplingRate = 1.0)
}

@Since("1.4.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0,
OldAlgo.Regression, getOldImpurity, Array())
OldAlgo.Regression, getOldImpurity)

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class Strategy @Since("1.3.0") (
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
@Since("2.0.0") var classWeights: Array[Double] = Array(1, 1))
@Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array(1.0, 1.0))
extends Serializable {

/**
Expand All @@ -101,6 +101,29 @@ class Strategy @Since("1.3.0") (
isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
}

/**
* Make the class compatible with previous versions
*/
@Since("2.0.0")
def this(
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClasses: Int,
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int, Int],
minInstancesPerNode: Int,
minInfoGain: Double,
maxMemoryInMB: Int,
subsamplingRate: Double,
useNodeIdCache: Boolean,
checkpointInterval: Int) {
this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy,
categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB,
subsamplingRate, useNodeIdCache, checkpointInterval, Array())
}

/**
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ private object RandomForestClassifierSuite extends SparkFunSuite {
val numFeatures = data.first().features.size
val oldStrategy =
rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification,
rf.getOldImpurity, rf.getClassWeights)
rf.getOldImpurity)
val oldModel = OldRandomForest.trainClassifier(
data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy,
rf.getSeed.toInt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
val numFeatures = data.first().features.size
val oldStrategy =
rf.getOldStrategy(categoricalFeatures, numClasses = 0,
OldAlgo.Regression, rf.getOldImpurity, classWeights = Array())
OldAlgo.Regression, rf.getOldImpurity)
val oldModel = OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy,
rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
Expand Down
2 changes: 1 addition & 1 deletion scalastyle-config.xml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ This file is divided into 3 sections:
</check>

<check level="error" class="org.scalastyle.scalariform.ParameterNumberChecker" enabled="true">
<parameters><parameter name="maxParameters"><![CDATA[10]]></parameter></parameters>
<parameters><parameter name="maxParameters"><![CDATA[15]]></parameter></parameters>
</check>

<check level="error" class="org.scalastyle.scalariform.NoFinalizeChecker" enabled="true"></check>
Expand Down