-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-17139][ML] Add model summary for MultinomialLogisticRegression #15435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
1290ff8
1727203
a96dc54
3c4b995
deddb00
2bce87b
ce95023
b6cde56
67c57e5
0ebc943
1395de2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -884,7 +884,7 @@ class LogisticRegression @Since("1.2.0") ( | |
| numClasses, isMultinomial)) | ||
|
|
||
| val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() | ||
| val logRegSummary = if (numClasses <=2) { | ||
| val logRegSummary = if (numClasses <= 2) { | ||
| new BinaryLogisticRegressionTrainingSummaryImpl( | ||
| summaryModel.transform(dataset), | ||
| probabilityColName, | ||
|
|
@@ -1017,15 +1017,19 @@ class LogisticRegressionModel private[spark] ( | |
| private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None | ||
|
|
||
| /** | ||
| * Gets summary of model on training set. An exception is | ||
| * thrown if `trainingSummary == None`. | ||
| * Gets summary of model on training set. An exception is thrown | ||
| * if `trainingSummary == None`. | ||
| */ | ||
| @Since("1.5.0") | ||
| def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse { | ||
| throw new SparkException("No training summary available for this LogisticRegressionModel") | ||
| } | ||
|
|
||
| @Since("2.2.0") | ||
| /** | ||
| * Gets summary of model on training set. An exception is thrown | ||
| * if `trainingSummary == None` or it is a multiclass model. | ||
| */ | ||
| @Since("2.3.0") | ||
| def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match { | ||
|
||
| case b: BinaryLogisticRegressionTrainingSummary => b | ||
| case _ => | ||
|
|
@@ -1357,23 +1361,23 @@ sealed trait LogisticRegressionSummary extends Serializable { | |
| /** | ||
| * Dataframe output by the model's `transform` method. | ||
| */ | ||
| @Since("2.3.0") | ||
| @Since("1.5.0") | ||
| def predictions: DataFrame | ||
|
|
||
| /** Field in "predictions" which gives the probability of each class as a vector. */ | ||
| @Since("2.3.0") | ||
| @Since("1.5.0") | ||
| def probabilityCol: String | ||
|
|
||
| /** Field in "predictions" which gives the predicted class of each instance. */ | ||
| @Since("2.3.0") | ||
| def predictionCol: String | ||
|
|
||
| /** Field in "predictions" which gives the true label of each instance (if available). */ | ||
| @Since("2.3.0") | ||
| @Since("1.5.0") | ||
| def labelCol: String | ||
|
|
||
| /** Field in "predictions" which gives the features of each instance as a vector. */ | ||
| @Since("2.3.0") | ||
| @Since("1.6.0") | ||
|
||
| def featuresCol: String | ||
|
|
||
| @transient private val multiclassMetrics = { | ||
|
||
|
|
@@ -1384,6 +1388,17 @@ sealed trait LogisticRegressionSummary extends Serializable { | |
| .rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) }) | ||
| } | ||
|
|
||
| /** | ||
| * Returns the sequence of labels in ascending order | ||
|
||
| * | ||
| * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the | ||
| * training set is missing a label, then all of the arrays over labels | ||
| * (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the | ||
| * expected numClasses. | ||
| */ | ||
| @Since("2.3.0") | ||
| def labels: Array[Double] = multiclassMetrics.labels | ||
|
|
||
| /** Returns true positive rate for each label (category). */ | ||
| @Since("2.3.0") | ||
| def truePositiveRateByLabel: Array[Double] = recallByLabel | ||
|
|
@@ -1561,7 +1576,6 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre | |
| with LogisticRegressionTrainingSummary | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Multiclass logistic regression training results. | ||
| * | ||
| * @param predictions dataframe output by the model's `transform` method. | ||
|
|
@@ -1574,18 +1588,17 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre | |
| * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. | ||
| */ | ||
| private class LogisticRegressionTrainingSummaryImpl( | ||
| override val predictions: DataFrame, | ||
| override val probabilityCol: String, | ||
| override val predictionCol: String, | ||
| override val labelCol: String, | ||
| override val featuresCol: String, | ||
| val objectiveHistory: Array[Double]) | ||
| predictions: DataFrame, | ||
| probabilityCol: String, | ||
| predictionCol: String, | ||
| labelCol: String, | ||
| featuresCol: String, | ||
| override val objectiveHistory: Array[Double]) | ||
|
||
| extends LogisticRegressionSummaryImpl( | ||
| predictions, probabilityCol, predictionCol, labelCol, featuresCol) | ||
| with LogisticRegressionTrainingSummary | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Multiclass logistic regression results for a given model. | ||
| * | ||
| * @param predictions dataframe output by the model's `transform` method. | ||
|
|
@@ -1605,7 +1618,6 @@ private class LogisticRegressionSummaryImpl( | |
| extends LogisticRegressionSummary | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Binary logistic regression training results. | ||
| * | ||
| * @param predictions dataframe output by the model's `transform` method. | ||
|
|
@@ -1618,18 +1630,17 @@ private class LogisticRegressionSummaryImpl( | |
| * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. | ||
| */ | ||
| private class BinaryLogisticRegressionTrainingSummaryImpl( | ||
| override val predictions: DataFrame, | ||
| override val probabilityCol: String, | ||
| override val predictionCol: String, | ||
| override val labelCol: String, | ||
| override val featuresCol: String, | ||
| predictions: DataFrame, | ||
| probabilityCol: String, | ||
| predictionCol: String, | ||
| labelCol: String, | ||
| featuresCol: String, | ||
| override val objectiveHistory: Array[Double]) | ||
| extends BinaryLogisticRegressionSummaryImpl( | ||
| predictions, probabilityCol, predictionCol, labelCol, featuresCol) | ||
| with BinaryLogisticRegressionTrainingSummary | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Binary logistic regression results for a given model. | ||
| * | ||
| * @param predictions dataframe output by the model's `transform` method. | ||
|
|
@@ -1641,11 +1652,11 @@ private class BinaryLogisticRegressionTrainingSummaryImpl( | |
| * @param featuresCol field in "predictions" which gives the features of each instance as a vector. | ||
| */ | ||
| private class BinaryLogisticRegressionSummaryImpl( | ||
| @transient override val predictions: DataFrame, | ||
| override val probabilityCol: String, | ||
| override val predictionCol: String, | ||
| override val labelCol: String, | ||
| override val featuresCol: String) | ||
| predictions: DataFrame, | ||
| probabilityCol: String, | ||
| predictionCol: String, | ||
| labelCol: String, | ||
| featuresCol: String) | ||
| extends LogisticRegressionSummaryImpl( | ||
| predictions, probabilityCol, predictionCol, labelCol, featuresCol) | ||
| with BinaryLogisticRegressionSummary | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -213,43 +213,44 @@ class LogisticRegressionSuite | |
| case (family, dataset) => | ||
|
||
| lr.setFamily(family) | ||
| lr.setProbabilityCol("").setPredictionCol("prediction") | ||
| val modelNoProb = lr.fit(smallBinaryDataset) | ||
| val modelNoProb = lr.fit(dataset) | ||
| checkSummarySchema(modelNoProb, Seq("probability_")) | ||
|
|
||
| lr.setProbabilityCol("probability").setPredictionCol("") | ||
| val modelNoPred = lr.fit(smallBinaryDataset) | ||
| val modelNoPred = lr.fit(dataset) | ||
| checkSummarySchema(modelNoPred, Seq("prediction_")) | ||
|
|
||
| lr.setProbabilityCol("").setPredictionCol("") | ||
| val modelNoPredNoProb = lr.fit(smallBinaryDataset) | ||
| val modelNoPredNoProb = lr.fit(dataset) | ||
| checkSummarySchema(modelNoPredNoProb, Seq("prediction_", "probability_")) | ||
| } | ||
| } | ||
|
|
||
| test("check summary types for binary and multiclass") { | ||
| val lr = new LogisticRegression() | ||
|
||
| .setFamily("binomial") | ||
| .setMaxIter(1) | ||
|
|
||
| val blorModel = lr.fit(smallBinaryDataset) | ||
| assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl]) | ||
| assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl]) | ||
| assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) | ||
| assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) | ||
|
|
||
| val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset) | ||
| assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummaryImpl]) | ||
| assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary]) | ||
| withClue("cannot get binary summary for multiclass model") { | ||
| intercept[RuntimeException] { | ||
| mlorModel.binarySummary | ||
| } | ||
| } | ||
|
|
||
| val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset) | ||
| assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl]) | ||
| assert(mlorBinaryModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl]) | ||
| assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) | ||
| assert(mlorBinaryModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) | ||
|
|
||
| val blorSummary = blorModel.evaluate(smallBinaryDataset) | ||
| val mlorSummary = mlorModel.evaluate(smallMultinomialDataset) | ||
| assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummaryImpl]) | ||
| assert(mlorSummary.isInstanceOf[LogisticRegressionSummaryImpl]) | ||
| assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummary]) | ||
| assert(mlorSummary.isInstanceOf[LogisticRegressionSummary]) | ||
| } | ||
|
|
||
| test("setThreshold, getThreshold") { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use isMultinomial since it is easier to read.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jkbradley The logic here isn't
`if (!isMultinomial) ... else ...`; it is
`if (!isMultinomial || (isMultinomial && numClasses <= 2)) ... else ...`. See the test case
`test("check summary types for binary and multiclass")`.