Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
SparkR:::glm supports family and link function
  • Loading branch information
yanboliang committed Apr 12, 2016
commit d6e95c06ef261eeff10c8e5032707f0c50c51457
112 changes: 55 additions & 57 deletions R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
#' @export
setClass("PipelineModel", representation(model = "jobj"))

#' @title S4 class that represents a GeneralizedLinearRegressionModel
#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper
#' @export
setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj"))

#' @title S4 class that represents a NaiveBayesModel
#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
#' @export
Expand Down Expand Up @@ -66,14 +71,55 @@ setClass("KMeansModel", representation(jobj = "jobj"))
#' summary(model)
#'}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0,
standardize = TRUE, solver = "auto") {
family <- match.arg(family)
function(formula, family = gaussian, data, epsilon = 1e-06, maxit = 25) {
if (is.character(family)) {
family <- get(family, mode = "function", envir = parent.frame())
}
if (is.function(family)) {
family <- family()
}
if (is.null(family$family)) {
print(family)
stop("'family' not recognized")
}

formula <- paste(deparse(formula), collapse = "")
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"fitRModelFormula", formula, data@sdf, family, lambda,
alpha, standardize, solver)
return(new("PipelineModel", model = model))

jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
epsilon, maxit)
return(new("GeneralizedLinearRegressionModel", jobj = jobj))
})

#' Get the summary of a model
#'
#' Returns the summary of a model produced by glm(), similarly to R's summary().
#'
#' @param object A fitted MLlib model
#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family
#' or a list with 'coefficients' component for binomial family. \cr
#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals
#' of the estimation, the 'coefficients' gives the estimated coefficients and their
#' estimated standard errors, t values and p-values. (It only available when model
#' fitted by normal solver.) \cr
#' For binomial family: the 'coefficients' gives the estimated coefficients.
#' See summary.glm for more information. \cr
#' @rdname summary
#' @export
#' @examples
#' \dontrun{
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#' }
setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
function(object, ...) {
jobj <- object@jobj
features <- callJMethod(jobj, "rFeatures")
coefficients <- callJMethod(jobj, "rCoefficients")
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Value")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
})

#' Make predictions from a model
Expand All @@ -91,9 +137,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
#' predicted <- predict(model, testData)
#' showDF(predicted)
#'}
setMethod("predict", signature(object = "PipelineModel"),
setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})

#' Make predictions from a naive Bayes model
Expand All @@ -116,54 +162,6 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})

#' Get the summary of a model
#'
#' Returns the summary of a model produced by glm(), similarly to R's summary().
#'
#' @param object A fitted MLlib model
#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family
#' or a list with 'coefficients' component for binomial family. \cr
#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals
#' of the estimation, the 'coefficients' gives the estimated coefficients and their
#' estimated standard errors, t values and p-values. (It only available when model
#' fitted by normal solver.) \cr
#' For binomial family: the 'coefficients' gives the estimated coefficients.
#' See summary.glm for more information. \cr
#' @rdname summary
#' @export
#' @examples
#' \dontrun{
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#'}
setMethod("summary", signature(object = "PipelineModel"),
function(object, ...) {
modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelName", object@model)
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelFeatures", object@model)
coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelCoefficients", object@model)
if (modelName == "LinearRegressionModel") {
devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelDevianceResiduals", object@model)
devianceResiduals <- matrix(devianceResiduals, nrow = 1)
colnames(devianceResiduals) <- c("Min", "Max")
rownames(devianceResiduals) <- rep("", times = 1)
coefficients <- matrix(coefficients, ncol = 4)
colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
rownames(coefficients) <- unlist(features)
return(list(devianceResiduals = devianceResiduals, coefficients = coefficients))
} else if (modelName == "LogisticRegressionModel") {
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
} else {
stop(paste("Unsupported model", modelName, sep = " "))
}
})

#' Get the summary of a naive Bayes model
#'
#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.r

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression._
import org.apache.spark.sql._

private[r] class GeneralizedLinearRegressionWrapper private (
pipeline: PipelineModel,
val features: Array[String]) {

private val glm: GeneralizedLinearRegressionModel =
pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]

lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
Array(glm.intercept) ++ glm.coefficients.toArray
} else {
glm.coefficients.toArray
}

lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
Array("(Intercept)") ++ features
} else {
features
}

def transform(dataset: DataFrame): DataFrame = {
pipeline.transform(dataset).drop(glm.getFeaturesCol)
}
}

private[r] object GeneralizedLinearRegressionWrapper {

def fit(
formula: String,
data: DataFrame,
family: String,
link: String,
epsilon: Double,
maxit: Int): GeneralizedLinearRegressionWrapper = {

val rFormula = new RFormula()
.setFormula(formula)
val rFormulaModel = rFormula.fit(data)

// get labels and feature names from output schema
val schema = rFormulaModel.transform(data).schema
val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
.attributes.get
val features = featureAttrs.map(_.name.get)
// assemble and fit the pipeline
val glm = new GeneralizedLinearRegression()
.setFamily(family)
.setLink(link)
.setFitIntercept(rFormula.hasIntercept)
.setTol(epsilon)
.setMaxIter(maxit)

val pipeline = new Pipeline()
.setStages(Array(rFormula, glm))
.fit(data)

new GeneralizedLinearRegressionWrapper(pipeline, features)
}
}