Skip to content
Closed
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
08831e7
[SPARK-14975][ML][WIP] Fixed GBTClassifier to predict probability per…
imatiach-msft Dec 30, 2016
e73b60f
Fixed scala style empty line
imatiach-msft Dec 30, 2016
d29b70d
Fixed binary compatibility tests
imatiach-msft Dec 30, 2016
d4afdd0
Fixing GBT classifier based on comments
imatiach-msft Jan 3, 2017
62702c8
Fixing probabilities calculated from raw scores
imatiach-msft Jan 5, 2017
27882b3
fixed scala style, multiplied raw prediction value by 2 in prob estimate
imatiach-msft Jan 5, 2017
8698d16
Updating based on code review, including code cleanup and adding bett…
imatiach-msft Jan 6, 2017
aaf1b06
Adding back constructor but making it private
imatiach-msft Jan 6, 2017
bafab79
updates to GBTClassifier based on comments
imatiach-msft Jan 10, 2017
2a6dea4
minor fixes to scala style
imatiach-msft Jan 10, 2017
52c5115
Fixing more scala style
imatiach-msft Jan 10, 2017
609a1b0
Using getOldLossType as per comments
imatiach-msft Jan 10, 2017
a28afe6
Added more tests for thresholds, fixed minor bug in predict to use th…
imatiach-msft Jan 10, 2017
9d5bb9b
Updated based on newest comments
imatiach-msft Jan 10, 2017
89965f5
missed one arg
imatiach-msft Jan 10, 2017
cacbbc1
Moving arg to its own line
imatiach-msft Jan 10, 2017
7396dac
Updated based on latest comments - moved classifier loss trait, updat…
imatiach-msft Jan 11, 2017
f2e041d
Fixed up minor comments
imatiach-msft Jan 11, 2017
1abfee0
Updated based on comments from jkbradley
imatiach-msft Jan 18, 2017
818de81
Fixing build issues - need to keep numClasses in model
imatiach-msft Jan 18, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fixing build issues - need to keep numClasses in model
  • Loading branch information
imatiach-msft committed Jan 18, 2017
commit 818de810cbfdf4fc671f21a262a34cc8554f9af6
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class GBTClassifier @Since("1.4.0") (

val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
$(seed))
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures, numClasses)
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
instr.logSuccess(m)
m
}
Expand Down Expand Up @@ -209,7 +209,8 @@ class GBTClassificationModel private[ml](
@Since("1.6.0") override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
@Since("1.6.0") override val numFeatures: Int)
@Since("1.6.0") override val numFeatures: Int,
@Since("2.2.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {
Expand All @@ -218,6 +219,20 @@ class GBTClassificationModel private[ml](
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")

/**
* Construct a GBTClassificationModel
*
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
* @param numFeatures The number of features.
*/
private[ml] def this(
uid: String,
_trees: Array[DecisionTreeRegressionModel],
_treeWeights: Array[Double],
numFeatures: Int) =
this(uid, _trees, _treeWeights, numFeatures, 2)

/**
* Construct a GBTClassificationModel
*
Expand All @@ -226,7 +241,7 @@ class GBTClassificationModel private[ml](
*/
@Since("1.6.0")
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
this(uid, _trees, _treeWeights, -1)
this(uid, _trees, _treeWeights, -1, 2)

@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
Expand Down Expand Up @@ -279,7 +294,7 @@ class GBTClassificationModel private[ml](

@Since("1.4.0")
override def copy(extra: ParamMap): GBTClassificationModel = {
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
extra).setParent(parent)
}

Expand Down Expand Up @@ -377,14 +392,15 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
oldModel: OldGBTModel,
parent: GBTClassifier,
categoricalFeatures: Map[Int, Int],
numFeatures: Int = -1): GBTClassificationModel = {
numFeatures: Int = -1,
numClasses: Int = 2): GBTClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses)
}
}