@@ -884,7 +884,7 @@ class LogisticRegression @Since("1.2.0") (
884884 numClasses, isMultinomial))
885885
886886 val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
887- val logRegSummary = if (numClasses <= 2 ) {
887+ val logRegSummary = if (numClasses <= 2 ) {
888888 new BinaryLogisticRegressionTrainingSummaryImpl (
889889 summaryModel.transform(dataset),
890890 probabilityColName,
@@ -1017,15 +1017,19 @@ class LogisticRegressionModel private[spark] (
10171017 private var trainingSummary : Option [LogisticRegressionTrainingSummary ] = None
10181018
10191019 /**
1020- * Gets summary of model on training set. An exception is
1021- * thrown if `trainingSummary == None`.
1020+ * Gets summary of model on training set. An exception is thrown
1021+ * if `trainingSummary == None`.
10221022 */
10231023 @ Since (" 1.5.0" )
10241024 def summary : LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
10251025 throw new SparkException (" No training summary available for this LogisticRegressionModel" )
10261026 }
10271027
1028- @ Since (" 2.2.0" )
1028+ /**
1029+ * Gets summary of model on training set. An exception is thrown
1030+ * if `trainingSummary == None` or it is a multiclass model.
1031+ */
1032+ @ Since (" 2.3.0" )
10291033 def binarySummary : BinaryLogisticRegressionTrainingSummary = summary match {
10301034 case b : BinaryLogisticRegressionTrainingSummary => b
10311035 case _ =>
@@ -1357,23 +1361,23 @@ sealed trait LogisticRegressionSummary extends Serializable {
13571361 /**
13581362 * Dataframe output by the model's `transform` method.
13591363 */
1360- @ Since (" 2.3 .0" )
1364+ @ Since (" 1.5 .0" )
13611365 def predictions : DataFrame
13621366
13631367 /** Field in "predictions" which gives the probability of each class as a vector. */
1364- @ Since (" 2.3 .0" )
1368+ @ Since (" 1.5 .0" )
13651369 def probabilityCol : String
13661370
13671371 /** Field in "predictions" which gives the prediction of each class. */
13681372 @ Since (" 2.3.0" )
13691373 def predictionCol : String
13701374
13711375 /** Field in "predictions" which gives the true label of each instance (if available). */
1372- @ Since (" 2.3 .0" )
1376+ @ Since (" 1.5 .0" )
13731377 def labelCol : String
13741378
13751379 /** Field in "predictions" which gives the features of each instance as a vector. */
1376- @ Since (" 2.3 .0" )
1380+ @ Since (" 1.6 .0" )
13771381 def featuresCol : String
13781382
13791383 @ transient private val multiclassMetrics = {
@@ -1384,6 +1388,17 @@ sealed trait LogisticRegressionSummary extends Serializable {
13841388 .rdd.map { case Row (prediction : Double , label : Double ) => (prediction, label) })
13851389 }
13861390
1391+ /**
1392+ * Returns the sequence of labels in ascending order
1393+ *
1394+ * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the
1395+ * training set is missing a label, then all of the arrays over labels
1396+ * (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
1397+ * expected numClasses.
1398+ */
1399+ @ Since (" 2.3.0" )
1400+ def labels : Array [Double ] = multiclassMetrics.labels
1401+
13871402 /** Returns true positive rate for each label (category). */
13881403 @ Since (" 2.3.0" )
13891404 def truePositiveRateByLabel : Array [Double ] = recallByLabel
@@ -1561,7 +1576,6 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre
15611576 with LogisticRegressionTrainingSummary
15621577
15631578/**
1564- * :: Experimental ::
15651579 * Multiclass logistic regression training results.
15661580 *
15671581 * @param predictions dataframe output by the model's `transform` method.
@@ -1574,18 +1588,17 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre
15741588 * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
15751589 */
15761590private class LogisticRegressionTrainingSummaryImpl (
1577- override val predictions : DataFrame ,
1578- override val probabilityCol : String ,
1579- override val predictionCol : String ,
1580- override val labelCol : String ,
1581- override val featuresCol : String ,
1582- val objectiveHistory : Array [Double ])
1591+ predictions : DataFrame ,
1592+ probabilityCol : String ,
1593+ predictionCol : String ,
1594+ labelCol : String ,
1595+ featuresCol : String ,
1596+ override val objectiveHistory : Array [Double ])
15831597 extends LogisticRegressionSummaryImpl (
15841598 predictions, probabilityCol, predictionCol, labelCol, featuresCol)
15851599 with LogisticRegressionTrainingSummary
15861600
15871601/**
1588- * :: Experimental ::
15891602 * Multiclass logistic regression results for a given model.
15901603 *
15911604 * @param predictions dataframe output by the model's `transform` method.
@@ -1605,7 +1618,6 @@ private class LogisticRegressionSummaryImpl(
16051618 extends LogisticRegressionSummary
16061619
16071620/**
1608- * :: Experimental ::
16091621 * Binary logistic regression training results.
16101622 *
16111623 * @param predictions dataframe output by the model's `transform` method.
@@ -1618,18 +1630,17 @@ private class LogisticRegressionSummaryImpl(
16181630 * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
16191631 */
16201632private class BinaryLogisticRegressionTrainingSummaryImpl (
1621- override val predictions : DataFrame ,
1622- override val probabilityCol : String ,
1623- override val predictionCol : String ,
1624- override val labelCol : String ,
1625- override val featuresCol : String ,
1633+ predictions : DataFrame ,
1634+ probabilityCol : String ,
1635+ predictionCol : String ,
1636+ labelCol : String ,
1637+ featuresCol : String ,
16261638 override val objectiveHistory : Array [Double ])
16271639 extends BinaryLogisticRegressionSummaryImpl (
16281640 predictions, probabilityCol, predictionCol, labelCol, featuresCol)
16291641 with BinaryLogisticRegressionTrainingSummary
16301642
16311643/**
1632- * :: Experimental ::
16331644 * Binary logistic regression results for a given model.
16341645 *
16351646 * @param predictions dataframe output by the model's `transform` method.
@@ -1641,11 +1652,11 @@ private class BinaryLogisticRegressionTrainingSummaryImpl(
16411652 * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
16421653 */
16431654private class BinaryLogisticRegressionSummaryImpl (
1644- @ transient override val predictions : DataFrame ,
1645- override val probabilityCol : String ,
1646- override val predictionCol : String ,
1647- override val labelCol : String ,
1648- override val featuresCol : String )
1655+ predictions : DataFrame ,
1656+ probabilityCol : String ,
1657+ predictionCol : String ,
1658+ labelCol : String ,
1659+ featuresCol : String )
16491660 extends LogisticRegressionSummaryImpl (
16501661 predictions, probabilityCol, predictionCol, labelCol, featuresCol)
16511662 with BinaryLogisticRegressionSummary
0 commit comments