Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
08831e7
[SPARK-14975][ML][WIP] Fixed GBTClassifier to predict probability per…
imatiach-msft Dec 30, 2016
e73b60f
Fixed scala style empty line
imatiach-msft Dec 30, 2016
d29b70d
Fixed binary compatibility tests
imatiach-msft Dec 30, 2016
d4afdd0
Fixing GBT classifier based on comments
imatiach-msft Jan 3, 2017
62702c8
Fixing probabilities calculated from raw scores
imatiach-msft Jan 5, 2017
27882b3
fixed scala style, multiplied raw prediction value by 2 in prob estimate
imatiach-msft Jan 5, 2017
8698d16
Updating based on code review, including code cleanup and adding bett…
imatiach-msft Jan 6, 2017
aaf1b06
Adding back constructor but making it private
imatiach-msft Jan 6, 2017
bafab79
updates to GBTClassifier based on comments
imatiach-msft Jan 10, 2017
2a6dea4
minor fixes to scala style
imatiach-msft Jan 10, 2017
52c5115
Fixing more scala style
imatiach-msft Jan 10, 2017
609a1b0
Using getOldLossType as per comments
imatiach-msft Jan 10, 2017
a28afe6
Added more tests for thresholds, fixed minor bug in predict to use th…
imatiach-msft Jan 10, 2017
9d5bb9b
Updated based on newest comments
imatiach-msft Jan 10, 2017
89965f5
missed one arg
imatiach-msft Jan 10, 2017
cacbbc1
Moving arg to its own line
imatiach-msft Jan 10, 2017
7396dac
Updated based on latest comments - moved classifier loss trait, updat…
imatiach-msft Jan 11, 2017
f2e041d
Fixed up minor comments
imatiach-msft Jan 11, 2017
1abfee0
Updated based on comments from jkbradley
imatiach-msft Jan 18, 2017
818de81
Fixing build issues - need to keep numClasses in model
imatiach-msft Jan 18, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Updated based on newest comments
  • Loading branch information
imatiach-msft committed Jan 18, 2017
commit 9d5bb9b598903583c95b4de3142d23106c971e55
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class GBTClassifier @Since("1.4.0") (
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)

val numClasses: Int = getNumClasses(dataset)
val numClasses: Int = 2
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
Expand Down Expand Up @@ -229,8 +229,9 @@ class GBTClassificationModel private[ml](
* @param numFeatures The number of features.
*/
private[ml] def this(uid: String, _trees: Array[DecisionTreeRegressionModel],
_treeWeights: Array[Double], numFeatures: Int) =
this(uid, _trees, _treeWeights, numFeatures, 2)
_treeWeights: Array[Double],
numFeatures: Int) =
this(uid, _trees, _treeWeights, numFeatures, 2)

/**
* Construct a GBTClassificationModel
Expand All @@ -240,7 +241,7 @@ class GBTClassificationModel private[ml](
*/
@Since("1.6.0")
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
this(uid, _trees, _treeWeights, -1, 2)
this(uid, _trees, _treeWeights, -1, 2)

@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
Expand All @@ -267,8 +268,7 @@ class GBTClassificationModel private[ml](
if (isDefined(thresholds)) {
super.predict(features)
} else {
val prediction: Double = margin(features)
if (prediction > 0.0) 1.0 else 0.0
if (margin(features) > 0.0) 1.0 else 0.0
}
}

Expand All @@ -279,12 +279,8 @@ class GBTClassificationModel private[ml](

override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
rawPrediction match {
// The probability can be calculated for positive result:
// p+(x) = 1 / (1 + e^(-2 * F(x)))
// and negative result:
// p-(x) = 1 / (1 + e^(2 * F(x)))
case dv: DenseVector =>
dv.values(0) = getOldLossType.computeProbability(dv.values(0))
dv.values(0) = loss.computeProbability(dv.values(0))
dv.values(1) = 1.0 - dv.values(0)
dv
case sv: SparseVector =>
Expand Down Expand Up @@ -330,6 +326,12 @@ class GBTClassificationModel private[ml](
new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
}

/**
* Note: this is currently an optimization that should be removed when we have more loss
* functions available than only logistic.
*/
private lazy val loss = getOldLossType
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why lazy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed lazy, removed comment. I made it lazy so as to not do the lookup if it doesn't need to be done, but since that isn't actually expensive and it only seemed to cause confusion, it's better to remove it.


@Since("2.0.0")
override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,9 @@ import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.mllib.util.MLUtils

/**
* :: DeveloperApi ::
* Trait for adding "pluggable" probability function for the gradient boosting algorithm.
* Trait for adding probability function for the gradient boosting algorithm.
*/
@Since("1.2.0")
@DeveloperApi
trait ClassificationLoss extends Loss {
private[spark] trait ClassificationLoss extends Loss {
private[spark] def computeProbability(prediction: Double): Double
}

Expand Down Expand Up @@ -63,6 +60,8 @@ object LogLoss extends ClassificationLoss {
}

override private[spark] def computeProbability(prediction: Double): Double = {
1 / (1 + math.exp(-2 * prediction))
// The probability can be calculated as:
// p+(x) = 1 / (1 + e^(-2 * F(x)))
1.0 / (1.0 + math.exp(-2.0 * prediction))
}
}