Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,9 @@ setClass("PipelineModel", representation(model = "jobj"))
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param family a description of the error distribution and link function to be used in the model..
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

link to R's family doc

#' @param lambda Regularization parameter
#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
#' @param standardize Whether to standardize features before training
#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and
#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory
#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an
#' analytical solution to the linear regression problem. The default value is "auto"
#' which means that the solver algorithm is selected automatically.
#' @param solver Currently only support "irls" which is also the default solver.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous comment was more explicit, especially with respect to 'auto' (the default). It should mention auto and irls as the two options.

#' @return a fitted MLlib model
#' @rdname glm
#' @export
Expand All @@ -51,13 +45,12 @@ setClass("PipelineModel", representation(model = "jobj"))
#' summary(model)
#'}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0,
standardize = TRUE, solver = "auto") {
function(formula, family = c("gaussian", "binomial", "poisson", "gamma"), data,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should match R's signature for family now. We can support family/link functions that R supports. If user input binomial("logit"), we can extract the family name and the link name before we call the Scala implementation.

lambda = 0, solver = "irls") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep the solver to 'auto', so that we can change the implementation of the solver without regression. However, the option 'irls' is available for users who want to use it.

As for the other options for the solver, see my comment in the ticket.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on "auto"
@thunterdb I don't see any comments in the ticket. Could you copy them over?

family <- match.arg(family)
formula <- paste(deparse(formula), collapse="")
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"fitRModelFormula", formula, data@sdf, family, lambda,
alpha, standardize, solver)
"fitGLM", formula, data@sdf, family, lambda, solver)
return(new("PipelineModel", model = model))
})

Expand Down Expand Up @@ -124,6 +117,11 @@ setMethod("summary", signature(object = "PipelineModel"),
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
} else if (modelName == "GeneralizedLinearRegressionModel") {
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
} else if (modelName == "KMeansModel") {
modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getKMeansModelSize", object@model)
Expand Down
44 changes: 43 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,41 @@

package org.apache.spark.ml.api.r

import org.apache.spark.SparkException
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.regression._
import org.apache.spark.sql.DataFrame

private[r] object SparkRWrappers {
def fitGLM(
value: String,
df: DataFrame,
family: String,
lambda: Double,
solver: String): PipelineModel = {
if (solver.trim != "irls") throw new SparkException("Currently only support irls")

val formula = new RFormula().setFormula(value)
val regex = "^\\s*(\\w+)\\s*(\\(\\s*link\\s*=\\s*\"(\\w+)\"\\s*\\))?\\s*$".r
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in order to minimize the escaping, you can use Scala's raw strings:

"""^\s*(\w+...\s*$""".r

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use regex here. Extract the names on R side:

> b <- binomial(link = "logit")
> b$family
[1] "binomial"
> b$link
[1] "logit"

val estimator = family match {
case regex(familyName, group2, linkName) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused: why do you need a regex here? I do not see anything special on the other side in R.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. The regex is unnecessary at here, RFormula can parse formula and handle illegal formula for glm.
You may noticed that I use regex in #11447, that is because we only support a subset of the formula in survreg currently.

val estimator = new GeneralizedLinearRegression()
.setFamily(familyName)
.setRegParam(lambda)
.setFitIntercept(formula.hasIntercept)
if (linkName != null) estimator.setLink(linkName)
estimator
case _ => throw new SparkException(s"Could not parse family: $family")
}

val pipeline = new Pipeline().setStages(Array(formula, estimator))
pipeline.fit(df)
}

def fitRModelFormula(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that method is not used anymore, right? We should remove it.

value: String,
df: DataFrame,
Expand Down Expand Up @@ -91,6 +117,12 @@ private[r] object SparkRWrappers {
}
case m: KMeansModel =>
m.clusterCenters.flatMap(_.toArray)
case m: GeneralizedLinearRegressionModel =>
if (m.getFitIntercept) {
Array(m.intercept) ++ m.coefficients.toArray
} else {
m.coefficients.toArray
}
}
}

Expand Down Expand Up @@ -151,6 +183,14 @@ private[r] object SparkRWrappers {
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
attrs.attributes.get.map(_.name.get)
case m: GeneralizedLinearRegressionModel =>
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
if (m.getFitIntercept) {
Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
} else {
attrs.attributes.get.map(_.name.get)
}
}
}

Expand All @@ -162,6 +202,8 @@ private[r] object SparkRWrappers {
"LogisticRegressionModel"
case m: KMeansModel =>
"KMeansModel"
case m: GeneralizedLinearRegressionModel =>
"GeneralizedLinearRegressionModel"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -208,26 +208,29 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
Instance(label, weight, features)
}

if (familyObj == Gaussian && linkObj == Identity) {
val model = if (familyObj == Gaussian && linkObj == Identity) {
// TODO: Make standardizeFeatures and standardizeLabel configurable.
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
standardizeFeatures = true, standardizeLabel = true)
val wlsModel = optimizer.fit(instances)
val model = copyValues(
copyValues(
new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept)
.setParent(this))
return model
}
else {
// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam))
val optimizer = new IterativelyReweightedLeastSquares(initialModel,
familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol))
val irlsModel = optimizer.fit(instances)
copyValues(
new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
.setParent(this))
}

// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam))
val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc,
$(fitIntercept), $(regParam), $(maxIter), $(tol))
val irlsModel = optimizer.fit(instances)

val model = copyValues(
new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
.setParent(this))
val summary = new GeneralizedLinearRegressionSummary(model.transform(dataset),
$(predictionCol), $(labelCol), $(featuresCol))
model.setSummary(summary)
model
}

Expand Down Expand Up @@ -569,9 +572,46 @@ class GeneralizedLinearRegressionModel private[ml] (
familyAndLink.fitted(eta)
}

private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None

private[regression] def setSummary(summary: GeneralizedLinearRegressionSummary): this.type = {
this.trainingSummary = Some(summary)
this
}

/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@Since("2.0.0")
def summary: GeneralizedLinearRegressionSummary = trainingSummary match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can also do

trainingSummary.getOrElse {
  throw new Exception(...)
}

case Some(summ) => summ
case None =>
throw new SparkException(
"No training summary available for this GeneralizedLinearRegressionModel",
new NullPointerException())
}

@Since("2.0.0")
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
.setParent(parent)
}
}

/**
* :: Experimental ::
* GeneralizedLinearRegressionModel results evaluated on a dataset.
*
* @param predictions dataframe outputted by the model's `transform` method.
* @param predictionCol field in "predictions" which gives the prediction of each instance.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
@Experimental
@Since("2.0.0")
class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0") @transient val predictions: DataFrame,
@Since("2.0.0") val predictionCol: String,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on a private discussion with @jkbradley , we should not expose the name of the columns in the public API. Users should be referring to the original model to get access to the column. In LinearRegressionSummary, they are passed around because they are required for metrics. We do not need them here.

@Since("2.0.0") val labelCol: String,
@Since("2.0.0") val featuresCol: String) extends Serializable