Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions R/pkg/inst/tests/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,20 @@ test_that("summary coefficients match with native glm", {
as.character(stats$features) ==
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
})

test_that("summary coefficients match with native glm of family 'binomial'", {
df <- createDataFrame(sqlContext, iris)
training <- filter(df, df$Species != "setosa")
stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
family = "binomial"))
coefs <- as.vector(stats$coefficients)

rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit"))))

expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(
as.character(stats$features) ==
c("(Intercept)", "Sepal_Length", "Sepal_Width")))
})
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ class LogisticRegression(override val uid: String)
model.transform(dataset),
$(probabilityCol),
$(labelCol),
$(featuresCol),
objectiveHistory)
model.setSummary(logRegSummary)
}
Expand Down Expand Up @@ -452,7 +453,8 @@ class LogisticRegressionModel private[ml] (
*/
// TODO: decide on a good name before exposing to public API
private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol))
new BinaryLogisticRegressionSummary(
this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol))
}

/**
Expand Down Expand Up @@ -614,9 +616,12 @@ sealed trait LogisticRegressionSummary extends Serializable {
/** Field in "predictions" which gives the calibrated probability of each instance as a vector. */
def probabilityCol: String

/** Field in "predictions" which gives the the true label of each instance. */
/** Field in "predictions" which gives the true label of each instance. */
def labelCol: String

/** Field in "predictions" which gives the features of each instance as a vector. */
def featuresCol: String

}

/**
Expand All @@ -626,15 +631,17 @@ sealed trait LogisticRegressionSummary extends Serializable {
* @param probabilityCol field in "predictions" which gives the calibrated probability of
* each instance as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Experimental
class BinaryLogisticRegressionTrainingSummary private[classification] (
predictions: DataFrame,
probabilityCol: String,
labelCol: String,
featuresCol: String,
val objectiveHistory: Array[Double])
extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol)
with LogisticRegressionTrainingSummary {

}
Expand All @@ -646,12 +653,14 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
* @param probabilityCol field in "predictions" which gives the calibrated probability of
* each instance.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
@Experimental
class BinaryLogisticRegressionSummary private[classification] (
@transient override val predictions: DataFrame,
override val probabilityCol: String,
override val labelCol: String) extends LogisticRegressionSummary {
override val labelCol: String,
override val featuresCol: String) extends LogisticRegressionSummary {

private val sqlContext = predictions.sqlContext
import sqlContext.implicits._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ private[r] object SparkRWrappers {
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
case _: LogisticRegressionModel =>
throw new UnsupportedOperationException(
"No features names available for LogisticRegressionModel") // SPARK-9492
case m: LogisticRegressionModel =>
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
}
}
}
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.ml.classification.LogisticAggregator.add"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.ml.classification.LogisticAggregator.count")
"org.apache.spark.ml.classification.LogisticAggregator.count"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol")
) ++ Seq(
// SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message.
// This class is marked as `private` but MiMa still seems to be confused by the change.
Expand Down