Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update
  • Loading branch information
WeichenXu123 committed Aug 21, 2017
commit b6cde56f18caa85f79b9cd0dc604ae1a46fd4948
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ class LogisticRegression @Since("1.2.0") (
numClasses, isMultinomial))

val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
val logRegSummary = if (numClasses <=2) {
val logRegSummary = if (numClasses <= 2) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use isMultinomial since it is easier to read.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jkbradley The logic here isn't if (!isMultinomial) ... else ...
it is if (!isMultinomial || (isMultinomial && numClasses <= 2)) ... else ...
see testcase test("check summary types for binary and multiclass")

new BinaryLogisticRegressionTrainingSummaryImpl(
summaryModel.transform(dataset),
probabilityColName,
Expand Down Expand Up @@ -1017,15 +1017,19 @@ class LogisticRegressionModel private[spark] (
private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None

/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
* Gets summary of model on training set. An exception is thrown
* if `trainingSummary == None`.
*/
@Since("1.5.0")
def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
throw new SparkException("No training summary available for this LogisticRegressionModel")
}

@Since("2.2.0")
/**
* Gets summary of model on training set. An exception is thrown
* if `trainingSummary == None` or it is a multiclass model.
*/
@Since("2.3.0")
def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior we are implementing is non-trivial. We need to add tests to ensure that everything happens as expected. This is an example:

test("binary and multiclass summary") {
    val lr = new LogisticRegression()
      .setFamily("binomial")

    val blorModel = lr.fit(smallBinaryDataset)
    assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl])
    assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl])

    val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
    assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummaryImpl])
    withClue("cannot get binary summary for multiclass model") {
      intercept[RuntimeException] {
        mlorModel.binarySummary
      }
    }

    val blorSummary = blorModel.evaluate(smallBinaryDataset)
    val mlorSummary = mlorModel.evaluate(smallMultinomialDataset)
    assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummaryImpl])
    assert(mlorSummary.isInstanceOf[LogisticRegressionSummaryImpl])
  }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could I ask why we need a separate binarySummary? If we get a binary classification model, the trainingSummary should be the corresponding binary summary. Is that correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For convenience. Otherwise, users need to call model.summary.asInstanceOf[BinaryLogisticRegressionTrainingSummary] in order to access the extra, binary-specific methods.

case b: BinaryLogisticRegressionTrainingSummary => b
case _ =>
Expand Down Expand Up @@ -1357,23 +1361,23 @@ sealed trait LogisticRegressionSummary extends Serializable {
/**
* Dataframe output by the model's `transform` method.
*/
@Since("2.3.0")
@Since("1.5.0")
def predictions: DataFrame

/** Field in "predictions" which gives the probability of each class as a vector. */
@Since("2.3.0")
@Since("1.5.0")
def probabilityCol: String

/** Field in "predictions" which gives the prediction of each class. */
@Since("2.3.0")
def predictionCol: String

/** Field in "predictions" which gives the true label of each instance (if available). */
@Since("2.3.0")
@Since("1.5.0")
def labelCol: String

/** Field in "predictions" which gives the features of each instance as a vector. */
@Since("2.3.0")
@Since("1.6.0")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

current master code featuresCol is marked "1.6.0".
I do not remove the @since into concrete impl class because they are all private.

def featuresCol: String

@transient private val multiclassMetrics = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MulticlassMetrics provides a labels field which returns the list of labels. In most cases, this will be values {0.0, 1.0, ..., numClasses-1}. However, if the training set is missing a label, then all of the arrays over labels (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the expected numClasses. In the future, it'd be nice to fix this by having them always be of length numClasses. For now, how about we provide the labels field with this kind of explanation?

Expand All @@ -1384,6 +1388,17 @@ sealed trait LogisticRegressionSummary extends Serializable {
.rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) })
}

/**
* Returns the sequence of labels in ascending order
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clarify: "Returns the sequence of labels in ascending order. This order matches the order used in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel."

*
* Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the
* training set is missing a label, then all of the arrays over labels
* (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
* expected numClasses.
*/
@Since("2.3.0")
def labels: Array[Double] = multiclassMetrics.labels

/** Returns true positive rate for each label (category). */
@Since("2.3.0")
def truePositiveRateByLabel: Array[Double] = recallByLabel
Expand Down Expand Up @@ -1561,7 +1576,6 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre
with LogisticRegressionTrainingSummary

/**
* :: Experimental ::
* Multiclass logistic regression training results.
*
* @param predictions dataframe output by the model's `transform` method.
Expand All @@ -1574,18 +1588,17 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
private class LogisticRegressionTrainingSummaryImpl(
override val predictions: DataFrame,
override val probabilityCol: String,
override val predictionCol: String,
override val labelCol: String,
override val featuresCol: String,
val objectiveHistory: Array[Double])
predictions: DataFrame,
probabilityCol: String,
predictionCol: String,
labelCol: String,
featuresCol: String,
override val objectiveHistory: Array[Double])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This val cannot be removed because in base class it is a def so override here is required.

extends LogisticRegressionSummaryImpl(
predictions, probabilityCol, predictionCol, labelCol, featuresCol)
with LogisticRegressionTrainingSummary

/**
* :: Experimental ::
* Multiclass logistic regression results for a given model.
*
* @param predictions dataframe output by the model's `transform` method.
Expand All @@ -1605,7 +1618,6 @@ private class LogisticRegressionSummaryImpl(
extends LogisticRegressionSummary

/**
* :: Experimental ::
* Binary logistic regression training results.
*
* @param predictions dataframe output by the model's `transform` method.
Expand All @@ -1618,18 +1630,17 @@ private class LogisticRegressionSummaryImpl(
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
private class BinaryLogisticRegressionTrainingSummaryImpl(
override val predictions: DataFrame,
override val probabilityCol: String,
override val predictionCol: String,
override val labelCol: String,
override val featuresCol: String,
predictions: DataFrame,
probabilityCol: String,
predictionCol: String,
labelCol: String,
featuresCol: String,
override val objectiveHistory: Array[Double])
extends BinaryLogisticRegressionSummaryImpl(
predictions, probabilityCol, predictionCol, labelCol, featuresCol)
with BinaryLogisticRegressionTrainingSummary

/**
* :: Experimental ::
* Binary logistic regression results for a given model.
*
* @param predictions dataframe output by the model's `transform` method.
Expand All @@ -1641,11 +1652,11 @@ private class BinaryLogisticRegressionTrainingSummaryImpl(
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
private class BinaryLogisticRegressionSummaryImpl(
@transient override val predictions: DataFrame,
override val probabilityCol: String,
override val predictionCol: String,
override val labelCol: String,
override val featuresCol: String)
predictions: DataFrame,
probabilityCol: String,
predictionCol: String,
labelCol: String,
featuresCol: String)
extends LogisticRegressionSummaryImpl(
predictions, probabilityCol, predictionCol, labelCol, featuresCol)
with BinaryLogisticRegressionSummary
Original file line number Diff line number Diff line change
Expand Up @@ -213,43 +213,44 @@ class LogisticRegressionSuite
case (family, dataset) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are not using "dataset" in this. However, this logic should be the same for both binomial and multinomial families, so I'm OK if you just test the binomial case here.

lr.setFamily(family)
lr.setProbabilityCol("").setPredictionCol("prediction")
val modelNoProb = lr.fit(smallBinaryDataset)
val modelNoProb = lr.fit(dataset)
checkSummarySchema(modelNoProb, Seq("probability_"))

lr.setProbabilityCol("probability").setPredictionCol("")
val modelNoPred = lr.fit(smallBinaryDataset)
val modelNoPred = lr.fit(dataset)
checkSummarySchema(modelNoPred, Seq("prediction_"))

lr.setProbabilityCol("").setPredictionCol("")
val modelNoPredNoProb = lr.fit(smallBinaryDataset)
val modelNoPredNoProb = lr.fit(dataset)
checkSummarySchema(modelNoPredNoProb, Seq("prediction_", "probability_"))
}
}

test("check summary types for binary and multiclass") {
val lr = new LogisticRegression()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set maxIter = 1 for speed

.setFamily("binomial")
.setMaxIter(1)

val blorModel = lr.fit(smallBinaryDataset)
assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl])
assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl])
assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])

val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummaryImpl])
assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary])
withClue("cannot get binary summary for multiclass model") {
intercept[RuntimeException] {
mlorModel.binarySummary
}
}

val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset)
assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl])
assert(mlorBinaryModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummaryImpl])
assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
assert(mlorBinaryModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])

val blorSummary = blorModel.evaluate(smallBinaryDataset)
val mlorSummary = mlorModel.evaluate(smallMultinomialDataset)
assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummaryImpl])
assert(mlorSummary.isInstanceOf[LogisticRegressionSummaryImpl])
assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummary])
assert(mlorSummary.isInstanceOf[LogisticRegressionSummary])
}

test("setThreshold, getThreshold") {
Expand Down