-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-12566] [ML] [WIP] GLM model family, link function support in SparkR:::glm #11549
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,15 +29,9 @@ setClass("PipelineModel", representation(model = "jobj")) | |
| #' @param formula A symbolic description of the model to be fitted. Currently only a few formula | ||
| #' operators are supported, including '~', '.', ':', '+', and '-'. | ||
| #' @param data DataFrame for training | ||
| #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. | ||
| #' @param family A description of the error distribution and link function to be used in the model. | ||
| #' @param lambda Regularization parameter | ||
| #' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) | ||
| #' @param standardize Whether to standardize features before training | ||
| #' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and | ||
| #' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory | ||
| #' quasi-Newton optimization method. "normal" denotes using Normal Equation as an | ||
| #' analytical solution to the linear regression problem. The default value is "auto" | ||
| #' which means that the solver algorithm is selected automatically. | ||
| #' @param solver Currently only support "irls" which is also the default solver. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The previous comment was more explicit, especially with respect to "auto" (the default). It should mention "auto" and "irls" as the two options. |
||
| #' @return a fitted MLlib model | ||
| #' @rdname glm | ||
| #' @export | ||
|
|
@@ -51,13 +45,12 @@ setClass("PipelineModel", representation(model = "jobj")) | |
| #' summary(model) | ||
| #'} | ||
| setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), | ||
| function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, | ||
| standardize = TRUE, solver = "auto") { | ||
| function(formula, family = c("gaussian", "binomial", "poisson", "gamma"), data, | ||
|
||
| lambda = 0, solver = "irls") { | ||
|
||
| family <- match.arg(family) | ||
| formula <- paste(deparse(formula), collapse="") | ||
| model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", | ||
| "fitRModelFormula", formula, data@sdf, family, lambda, | ||
| alpha, standardize, solver) | ||
| "fitGLM", formula, data@sdf, family, lambda, solver) | ||
| return(new("PipelineModel", model = model)) | ||
| }) | ||
|
|
||
|
|
@@ -124,6 +117,11 @@ setMethod("summary", signature(object = "PipelineModel"), | |
| colnames(coefficients) <- c("Estimate") | ||
| rownames(coefficients) <- unlist(features) | ||
| return(list(coefficients = coefficients)) | ||
| } else if (modelName == "GeneralizedLinearRegressionModel") { | ||
| coefficients <- as.matrix(unlist(coefficients)) | ||
| colnames(coefficients) <- c("Estimate") | ||
| rownames(coefficients) <- unlist(features) | ||
| return(list(coefficients = coefficients)) | ||
| } else if (modelName == "KMeansModel") { | ||
| modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", | ||
| "getKMeansModelSize", object@model) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,15 +17,41 @@ | |
|
|
||
| package org.apache.spark.ml.api.r | ||
|
|
||
| import org.apache.spark.SparkException | ||
| import org.apache.spark.ml.{Pipeline, PipelineModel} | ||
| import org.apache.spark.ml.attribute._ | ||
| import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} | ||
| import org.apache.spark.ml.clustering.{KMeans, KMeansModel} | ||
| import org.apache.spark.ml.feature.{RFormula, VectorAssembler} | ||
| import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} | ||
| import org.apache.spark.ml.regression._ | ||
| import org.apache.spark.sql.DataFrame | ||
|
|
||
| private[r] object SparkRWrappers { | ||
| def fitGLM( | ||
| value: String, | ||
| df: DataFrame, | ||
| family: String, | ||
| lambda: Double, | ||
| solver: String): PipelineModel = { | ||
| if (solver.trim != "irls") throw new SparkException("Currently only support irls") | ||
|
|
||
| val formula = new RFormula().setFormula(value) | ||
| val regex = "^\\s*(\\w+)\\s*(\\(\\s*link\\s*=\\s*\"(\\w+)\"\\s*\\))?\\s*$".r | ||
|
||
| val estimator = family match { | ||
| case regex(familyName, group2, linkName) => | ||
|
||
| val estimator = new GeneralizedLinearRegression() | ||
| .setFamily(familyName) | ||
| .setRegParam(lambda) | ||
| .setFitIntercept(formula.hasIntercept) | ||
| if (linkName != null) estimator.setLink(linkName) | ||
| estimator | ||
| case _ => throw new SparkException(s"Could not parse family: $family") | ||
| } | ||
|
|
||
| val pipeline = new Pipeline().setStages(Array(formula, estimator)) | ||
| pipeline.fit(df) | ||
| } | ||
|
|
||
| def fitRModelFormula( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That method is not used anymore, right? We should remove it. |
||
| value: String, | ||
| df: DataFrame, | ||
|
|
@@ -91,6 +117,12 @@ private[r] object SparkRWrappers { | |
| } | ||
| case m: KMeansModel => | ||
| m.clusterCenters.flatMap(_.toArray) | ||
| case m: GeneralizedLinearRegressionModel => | ||
| if (m.getFitIntercept) { | ||
| Array(m.intercept) ++ m.coefficients.toArray | ||
| } else { | ||
| m.coefficients.toArray | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -151,6 +183,14 @@ private[r] object SparkRWrappers { | |
| val attrs = AttributeGroup.fromStructField( | ||
| m.summary.predictions.schema(m.summary.featuresCol)) | ||
| attrs.attributes.get.map(_.name.get) | ||
| case m: GeneralizedLinearRegressionModel => | ||
| val attrs = AttributeGroup.fromStructField( | ||
| m.summary.predictions.schema(m.summary.featuresCol)) | ||
| if (m.getFitIntercept) { | ||
| Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) | ||
| } else { | ||
| attrs.attributes.get.map(_.name.get) | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -162,6 +202,8 @@ private[r] object SparkRWrappers { | |
| "LogisticRegressionModel" | ||
| case m: KMeansModel => | ||
| "KMeansModel" | ||
| case m: GeneralizedLinearRegressionModel => | ||
| "GeneralizedLinearRegressionModel" | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -208,26 +208,29 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val | |
| Instance(label, weight, features) | ||
| } | ||
|
|
||
| if (familyObj == Gaussian && linkObj == Identity) { | ||
| val model = if (familyObj == Gaussian && linkObj == Identity) { | ||
| // TODO: Make standardizeFeatures and standardizeLabel configurable. | ||
| val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), | ||
| standardizeFeatures = true, standardizeLabel = true) | ||
| val wlsModel = optimizer.fit(instances) | ||
| val model = copyValues( | ||
| copyValues( | ||
| new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) | ||
| .setParent(this)) | ||
| return model | ||
| } | ||
| else { | ||
| // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). | ||
| val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) | ||
| val optimizer = new IterativelyReweightedLeastSquares(initialModel, | ||
| familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol)) | ||
| val irlsModel = optimizer.fit(instances) | ||
| copyValues( | ||
| new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) | ||
| .setParent(this)) | ||
| } | ||
|
|
||
| // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). | ||
| val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) | ||
| val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, | ||
| $(fitIntercept), $(regParam), $(maxIter), $(tol)) | ||
| val irlsModel = optimizer.fit(instances) | ||
|
|
||
| val model = copyValues( | ||
| new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) | ||
| .setParent(this)) | ||
| val summary = new GeneralizedLinearRegressionSummary(model.transform(dataset), | ||
| $(predictionCol), $(labelCol), $(featuresCol)) | ||
| model.setSummary(summary) | ||
| model | ||
| } | ||
|
|
||
|
|
@@ -569,9 +572,46 @@ class GeneralizedLinearRegressionModel private[ml] ( | |
| familyAndLink.fitted(eta) | ||
| } | ||
|
|
||
| private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None | ||
|
|
||
| private[regression] def setSummary(summary: GeneralizedLinearRegressionSummary): this.type = { | ||
| this.trainingSummary = Some(summary) | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Gets summary of model on training set. An exception is | ||
| * thrown if `trainingSummary == None`. | ||
| */ | ||
| @Since("2.0.0") | ||
| def summary: GeneralizedLinearRegressionSummary = trainingSummary match { | ||
|
||
| case Some(summ) => summ | ||
| case None => | ||
| throw new SparkException( | ||
| "No training summary available for this GeneralizedLinearRegressionModel", | ||
| new NullPointerException()) | ||
| } | ||
|
|
||
| @Since("2.0.0") | ||
| override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { | ||
| copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) | ||
| .setParent(parent) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * GeneralizedLinearRegressionModel results evaluated on a dataset. | ||
| * | ||
| * @param predictions dataframe outputted by the model's `transform` method. | ||
| * @param predictionCol field in "predictions" which gives the prediction of each instance. | ||
| * @param labelCol field in "predictions" which gives the true label of each instance. | ||
| * @param featuresCol field in "predictions" which gives the features of each instance as a vector. | ||
| */ | ||
| @Experimental | ||
| @Since("2.0.0") | ||
| class GeneralizedLinearRegressionSummary private[regression] ( | ||
| @Since("2.0.0") @transient val predictions: DataFrame, | ||
| @Since("2.0.0") val predictionCol: String, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Based on a private discussion with @jkbradley, we should not expose the names of the columns in the public API. Users should refer to the original model to get access to the columns. In |
||
| @Since("2.0.0") val labelCol: String, | ||
| @Since("2.0.0") val featuresCol: String) extends Serializable | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Link to R's `family` doc.