Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
08831e7
[SPARK-14975][ML][WIP] Fixed GBTClassifier to predict probability per…
imatiach-msft Dec 30, 2016
e73b60f
Fixed scala style empty line
imatiach-msft Dec 30, 2016
d29b70d
Fixed binary compatibility tests
imatiach-msft Dec 30, 2016
d4afdd0
Fixing GBT classifier based on comments
imatiach-msft Jan 3, 2017
62702c8
Fixing probabilities calculated from raw scores
imatiach-msft Jan 5, 2017
27882b3
fixed scala style, multiplied raw prediction value by 2 in prob estimate
imatiach-msft Jan 5, 2017
8698d16
Updating based on code review, including code cleanup and adding bett…
imatiach-msft Jan 6, 2017
aaf1b06
Adding back constructor but making it private
imatiach-msft Jan 6, 2017
bafab79
updates to GBTClassifier based on comments
imatiach-msft Jan 10, 2017
2a6dea4
minor fixes to scala style
imatiach-msft Jan 10, 2017
52c5115
Fixing more scala style
imatiach-msft Jan 10, 2017
609a1b0
Using getOldLossType as per comments
imatiach-msft Jan 10, 2017
a28afe6
Added more tests for thresholds, fixed minor bug in predict to use th…
imatiach-msft Jan 10, 2017
9d5bb9b
Updated based on newest comments
imatiach-msft Jan 10, 2017
89965f5
missed one arg
imatiach-msft Jan 10, 2017
cacbbc1
Moving arg to its own line
imatiach-msft Jan 10, 2017
7396dac
Updated based on latest comments - moved classifier loss trait, updat…
imatiach-msft Jan 11, 2017
f2e041d
Fixed up minor comments
imatiach-msft Jan 11, 2017
1abfee0
Updated based on comments from jkbradley
imatiach-msft Jan 18, 2017
818de81
Fixing build issues - need to keep numClasses in model
imatiach-msft Jan 18, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Using getOldLossType as per comments
  • Loading branch information
imatiach-msft committed Jan 18, 2017
commit 609a1b0a29c9835f33196ec56736736e60acf232
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class GBTClassificationModel private[ml](
// and negative result:
// p-(x) = 1 / (1 + e^(2 * F(x)))
case dv: DenseVector =>
dv.values(0) = classProbability(getLossType, dv.values(0))
dv.values(0) = getOldLossType.computeProbability(dv.values(0))
dv.values(1) = 1.0 - dv.values(0)
dv
case sv: SparseVector =>
Expand Down Expand Up @@ -315,13 +315,6 @@ class GBTClassificationModel private[ml](
// Declared as a lazy val, so TreeEnsembleModel.featureImportances(trees, numFeatures)
// is only invoked on first access and the resulting Vector is reused afterwards.
@Since("2.0.0")
lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)

/**
 * Converts a raw prediction into a probability for the given loss name.
 *
 * Only the "logistic" loss is handled, by delegating to `LogLoss.computeProbability`;
 * every other loss name results in an exception.
 *
 * @param loss          name of the loss function
 * @param rawPrediction raw prediction value passed to the loss
 * @return the value returned by `LogLoss.computeProbability(rawPrediction)`
 */
private def classProbability(loss: String, rawPrediction: Double): Double =
  if (loss == "logistic") {
    LogLoss.computeProbability(rawPrediction)
  } else {
    // Any loss other than "logistic" is rejected outright.
    throw new Exception("Only logistic loss is supported ...")
  }

private def margin(features: Vector): Double = {
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

/**
Expand Down Expand Up @@ -531,7 +531,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam
def getLossType: String = {
  // Lower-case with Locale.ROOT rather than the JVM default locale: under a
  // Turkish/Azeri default locale "LOGISTIC".toLowerCase yields a dotless "ı",
  // which would no longer equal the "logistic" literal this value is matched against.
  $(lossType).toLowerCase(java.util.Locale.ROOT)
}

/** (private[ml]) Convert new loss to old loss. */
override private[ml] def getOldLossType: OldLoss = {
override private[ml] def getOldLossType: OldClassificationLoss = {
getLossType match {
case "logistic" => OldLogLoss
case _ =>
Expand Down