Skip to content

Commit b6cde56

Browse files
committed
update
1 parent ce95023 commit b6cde56

File tree

2 files changed

+50
-38
lines changed

2 files changed

+50
-38
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ class LogisticRegression @Since("1.2.0") (
884884
numClasses, isMultinomial))
885885

886886
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
887-
val logRegSummary = if (numClasses <=2) {
887+
val logRegSummary = if (numClasses <= 2) {
888888
new BinaryLogisticRegressionTrainingSummaryImpl(
889889
summaryModel.transform(dataset),
890890
probabilityColName,
@@ -1017,15 +1017,19 @@ class LogisticRegressionModel private[spark] (
10171017
private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
10181018

10191019
/**
1020-
* Gets summary of model on training set. An exception is
1021-
* thrown if `trainingSummary == None`.
1020+
* Gets summary of model on training set. An exception is thrown
1021+
* if `trainingSummary == None`.
10221022
*/
10231023
@Since("1.5.0")
10241024
def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
10251025
throw new SparkException("No training summary available for this LogisticRegressionModel")
10261026
}
10271027

1028-
@Since("2.2.0")
1028+
/**
1029+
* Gets summary of model on training set. An exception is thrown
1030+
* if `trainingSummary == None` or if the model is a multiclass model.
1031+
*/
1032+
@Since("2.3.0")
10291033
def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match {
10301034
case b: BinaryLogisticRegressionTrainingSummary => b
10311035
case _ =>
@@ -1357,23 +1361,23 @@ sealed trait LogisticRegressionSummary extends Serializable {
13571361
/**
13581362
* Dataframe output by the model's `transform` method.
13591363
*/
1360-
@Since("2.3.0")
1364+
@Since("1.5.0")
13611365
def predictions: DataFrame
13621366

13631367
/** Field in "predictions" which gives the probability of each class as a vector. */
1364-
@Since("2.3.0")
1368+
@Since("1.5.0")
13651369
def probabilityCol: String
13661370

13671371
/** Field in "predictions" which gives the prediction of each class. */
13681372
@Since("2.3.0")
13691373
def predictionCol: String
13701374

13711375
/** Field in "predictions" which gives the true label of each instance (if available). */
1372-
@Since("2.3.0")
1376+
@Since("1.5.0")
13731377
def labelCol: String
13741378

13751379
/** Field in "predictions" which gives the features of each instance as a vector. */
1376-
@Since("2.3.0")
1380+
@Since("1.6.0")
13771381
def featuresCol: String
13781382

13791383
@transient private val multiclassMetrics = {
@@ -1384,6 +1388,17 @@ sealed trait LogisticRegressionSummary extends Serializable {
13841388
.rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) })
13851389
}
13861390

1391+
/**
1392+
* Returns the sequence of labels in ascending order
1393+
*
1394+
* Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the
1395+
* training set is missing a label, then all of the arrays over labels
1396+
* (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
1397+
* expected numClasses.
1398+
*/
1399+
@Since("2.3.0")
1400+
def labels: Array[Double] = multiclassMetrics.labels
1401+
13871402
/** Returns true positive rate for each label (category). */
13881403
@Since("2.3.0")
13891404
def truePositiveRateByLabel: Array[Double] = recallByLabel
@@ -1561,7 +1576,6 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre
15611576
with LogisticRegressionTrainingSummary
15621577

15631578
/**
1564-
* :: Experimental ::
15651579
* Multiclass logistic regression training results.
15661580
*
15671581
* @param predictions dataframe output by the model's `transform` method.
@@ -1574,18 +1588,17 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre
15741588
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
15751589
*/
15761590
private class LogisticRegressionTrainingSummaryImpl(
1577-
override val predictions: DataFrame,
1578-
override val probabilityCol: String,
1579-
override val predictionCol: String,
1580-
override val labelCol: String,
1581-
override val featuresCol: String,
1582-
val objectiveHistory: Array[Double])
1591+
predictions: DataFrame,
1592+
probabilityCol: String,
1593+
predictionCol: String,
1594+
labelCol: String,
1595+
featuresCol: String,
1596+
override val objectiveHistory: Array[Double])
15831597
extends LogisticRegressionSummaryImpl(
15841598
predictions, probabilityCol, predictionCol, labelCol, featuresCol)
15851599
with LogisticRegressionTrainingSummary
15861600

15871601
/**
1588-
* :: Experimental ::
15891602
* Multiclass logistic regression results for a given model.
15901603
*
15911604
* @param predictions dataframe output by the model's `transform` method.
@@ -1605,7 +1618,6 @@ private class LogisticRegressionSummaryImpl(
16051618
extends LogisticRegressionSummary
16061619

16071620
/**
1608-
* :: Experimental ::
16091621
* Binary logistic regression training results.
16101622
*
16111623
* @param predictions dataframe output by the model's `transform` method.
@@ -1618,18 +1630,17 @@ private class LogisticRegressionSummaryImpl(
16181630
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
16191631
*/
16201632
private class BinaryLogisticRegressionTrainingSummaryImpl(
1621-
override val predictions: DataFrame,
1622-
override val probabilityCol: String,
1623-
override val predictionCol: String,
1624-
override val labelCol: String,
1625-
override val featuresCol: String,
1633+
predictions: DataFrame,
1634+
probabilityCol: String,
1635+
predictionCol: String,
1636+
labelCol: String,
1637+
featuresCol: String,
16261638
override val objectiveHistory: Array[Double])
16271639
extends BinaryLogisticRegressionSummaryImpl(
16281640
predictions, probabilityCol, predictionCol, labelCol, featuresCol)
16291641
with BinaryLogisticRegressionTrainingSummary
16301642

16311643
/**
1632-
* :: Experimental ::
16331644
* Binary logistic regression results for a given model.
16341645
*
16351646
* @param predictions dataframe output by the model's `transform` method.
@@ -1641,11 +1652,11 @@ private class BinaryLogisticRegressionTrainingSummaryImpl(
16411652
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
16421653
*/
16431654
private class BinaryLogisticRegressionSummaryImpl(
1644-
@transient override val predictions: DataFrame,
1645-
override val probabilityCol: String,
1646-
override val predictionCol: String,
1647-
override val labelCol: String,
1648-
override val featuresCol: String)
1655+
predictions: DataFrame,
1656+
probabilityCol: String,
1657+
predictionCol: String,
1658+
labelCol: String,
1659+
featuresCol: String)
16491660
extends LogisticRegressionSummaryImpl(
16501661
predictions, probabilityCol, predictionCol, labelCol, featuresCol)
16511662
with BinaryLogisticRegressionSummary

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -213,43 +213,44 @@ class LogisticRegressionSuite
213213
case (family, dataset) =>
214214
lr.setFamily(family)
215215
lr.setProbabilityCol("").setPredictionCol("prediction")
216-
val modelNoProb = lr.fit(smallBinaryDataset)
216+
val modelNoProb = lr.fit(dataset)
217217
checkSummarySchema(modelNoProb, Seq("probability_"))
218218

219219
lr.setProbabilityCol("probability").setPredictionCol("")
220-
val modelNoPred = lr.fit(smallBinaryDataset)
220+
val modelNoPred = lr.fit(dataset)
221221
checkSummarySchema(modelNoPred, Seq("prediction_"))
222222

223223
lr.setProbabilityCol("").setPredictionCol("")
224-
val modelNoPredNoProb = lr.fit(smallBinaryDataset)
224+
val modelNoPredNoProb = lr.fit(dataset)
225225
checkSummarySchema(modelNoPredNoProb, Seq("prediction_", "probability_"))
226226
}
227227
}
228228

229229
test("check summary types for binary and multiclass") {
230230
val lr = new LogisticRegression()
231231
.setFamily("binomial")
232+
.setMaxIter(1)
232233

233234
val blorModel = lr.fit(smallBinaryDataset)
234-
assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl])
235-
assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl])
235+
assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
236+
assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
236237

237238
val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
238-
assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummaryImpl])
239+
assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary])
239240
withClue("cannot get binary summary for multiclass model") {
240241
intercept[RuntimeException] {
241242
mlorModel.binarySummary
242243
}
243244
}
244245

245246
val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset)
246-
assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl])
247-
assert(mlorBinaryModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl])
247+
assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
248+
assert(mlorBinaryModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
248249

249250
val blorSummary = blorModel.evaluate(smallBinaryDataset)
250251
val mlorSummary = mlorModel.evaluate(smallMultinomialDataset)
251-
assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummaryImpl])
252-
assert(mlorSummary.isInstanceOf[LogisticRegressionSummaryImpl])
252+
assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummary])
253+
assert(mlorSummary.isInstanceOf[LogisticRegressionSummary])
253254
}
254255

255256
test("setThreshold, getThreshold") {

0 commit comments

Comments
 (0)