
Commit b4720aa

Merge branch 'master' into SPARK-18366

2 parents: 1922472 + 55964c1

File tree

10 files changed: +696 additions, −66 deletions

R/pkg/NAMESPACE

Lines changed: 7 additions & 2 deletions

@@ -45,7 +45,8 @@ exportMethods("glm",
               "spark.als",
               "spark.kstest",
               "spark.logit",
-              "spark.randomForest")
+              "spark.randomForest",
+              "spark.gbt")
 
 # Job group lifecycle management methods
 export("setJobGroup",
@@ -353,7 +354,9 @@ export("as.DataFrame",
        "read.ml",
        "print.summary.KSTest",
        "print.summary.RandomForestRegressionModel",
-       "print.summary.RandomForestClassificationModel")
+       "print.summary.RandomForestClassificationModel",
+       "print.summary.GBTRegressionModel",
+       "print.summary.GBTClassificationModel")
 
 export("structField",
        "structField.jobj",
@@ -380,6 +383,8 @@ S3method(print, summary.GeneralizedLinearRegressionModel)
 S3method(print, summary.KSTest)
 S3method(print, summary.RandomForestRegressionModel)
 S3method(print, summary.RandomForestClassificationModel)
+S3method(print, summary.GBTRegressionModel)
+S3method(print, summary.GBTClassificationModel)
 S3method(structField, character)
 S3method(structField, jobj)
 S3method(structType, jobj)
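
With these NAMESPACE entries in place, the new generic and its summary print methods become visible to SparkR users. A minimal usage sketch, assuming a running SparkR session; the call mirrors the classification test added below:

  library(SparkR)
  sparkR.session()

  # Binary labels only: GBTClassifier currently supports binary classification.
  df <- suppressWarnings(createDataFrame(iris[iris$Species != "virginica", ]))
  model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification")

  # Printing the summary dispatches to the exported
  # print.summary.GBTClassificationModel S3 method.
  print(summary(model))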

R/pkg/R/generics.R

Lines changed: 4 additions & 0 deletions

@@ -1343,6 +1343,10 @@ setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
 setGeneric("spark.gaussianMixture",
            function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
 
+#' @rdname spark.gbt
+#' @export
+setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") })
+
 #' @rdname spark.glm
 #' @export
 setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })

R/pkg/R/mllib.R

Lines changed: 286 additions & 45 deletions
Large diffs are not rendered by default.
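
The mllib.R diff carries the R half of the implementation but is not rendered here. As orientation only, a condensed sketch of the classification branch of spark.gbt, assuming SparkR's internal JVM bridge (callJStatic) and the S4 model-class pattern used by the other ML wrappers; the argument list mirrors GBTClassifierWrapper.fit shown further down, while default values and the GBTRegressorWrapper name are assumptions, not the actual diff:

  # Hedged sketch -- not the actual mllib.R diff.
  setClass("GBTClassificationModel", representation(jobj = "jobj"))

  setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
            function(data, formula, type = c("regression", "classification"),
                     maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1,
                     minInstancesPerNode = 1, minInfoGain = 0.0,
                     checkpointInterval = 10, lossType = NULL, seed = NULL,
                     subsamplingRate = 1.0, maxMemoryInMB = 256,
                     cacheNodeIds = FALSE) {
              type <- match.arg(type)
              formula <- paste(deparse(formula), collapse = "")
              # The JVM side takes the seed as a String (see fit below).
              if (!is.null(seed)) seed <- as.character(as.integer(seed))
              if (type == "classification") {
                if (is.null(lossType)) lossType <- "logistic"
                # Bridge into the JVM wrapper; argument order mirrors
                # GBTClassifierWrapper.fit below.
                jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper",
                                    "fit", data@sdf, formula,
                                    as.integer(maxDepth), as.integer(maxBins),
                                    as.integer(maxIter), as.numeric(stepSize),
                                    as.integer(minInstancesPerNode),
                                    as.numeric(minInfoGain),
                                    as.integer(checkpointInterval), lossType,
                                    seed, as.numeric(subsamplingRate),
                                    as.integer(maxMemoryInMB), cacheNodeIds)
                new("GBTClassificationModel", jobj = jobj)
              } else {
                # The regression branch goes through an analogous
                # GBTRegressorWrapper (assumed name); omitted here.
                stop("sketch covers only the classification branch")
              }
            })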

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 68 additions & 0 deletions

@@ -949,4 +949,72 @@ test_that("spark.randomForest Classification", {
   unlink(modelPath)
 })
 
+test_that("spark.gbt", {
+  # regression
+  data <- suppressWarnings(createDataFrame(longley))
+  model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123)
+  predictions <- collect(predict(model, data))
+  expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
+                                         63.221, 63.639, 64.989, 63.761,
+                                         66.019, 67.857, 68.169, 66.513,
+                                         68.655, 69.564, 69.331, 70.551),
+               tolerance = 1e-4)
+  stats <- summary(model)
+  expect_equal(stats$numTrees, 20)
+  expect_equal(stats$formula, "Employed ~ .")
+  expect_equal(stats$numFeatures, 6)
+  expect_equal(length(stats$treeWeights), 20)
+
+  modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$formula, stats2$formula)
+  expect_equal(stats$numFeatures, stats2$numFeatures)
+  expect_equal(stats$features, stats2$features)
+  expect_equal(stats$featureImportances, stats2$featureImportances)
+  expect_equal(stats$numTrees, stats2$numTrees)
+  expect_equal(stats$treeWeights, stats2$treeWeights)
+
+  unlink(modelPath)
+
+  # classification
+  # label must be binary - GBTClassifier currently only supports binary classification.
+  iris2 <- iris[iris$Species != "virginica", ]
+  data <- suppressWarnings(createDataFrame(iris2))
+  model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification")
+  stats <- summary(model)
+  expect_equal(stats$numFeatures, 2)
+  expect_equal(stats$numTrees, 20)
+  expect_error(capture.output(stats), NA)
+  expect_true(length(capture.output(stats)) > 6)
+  predictions <- collect(predict(model, data))$prediction
+  # test string prediction values
+  expect_equal(length(grep("setosa", predictions)), 50)
+  expect_equal(length(grep("versicolor", predictions)), 50)
+
+  modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$depth, stats2$depth)
+  expect_equal(stats$numNodes, stats2$numNodes)
+  expect_equal(stats$numClasses, stats2$numClasses)
+
+  unlink(modelPath)
+
+  iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
+  df <- suppressWarnings(createDataFrame(iris2))
+  m <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
+  s <- summary(m)
+  # test numeric prediction values
+  expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
+  expect_equal(s$numFeatures, 5)
+  expect_equal(s$numTrees, 20)
+})
+
 sparkR.session.stop()
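
For reference, this suite runs under testthat against a live Spark session; SparkR's bundled R/run-tests.sh is the canonical harness for the full suite. An interactive sketch, assuming a local Spark installation is on the path:

  library(testthat)
  library(SparkR)

  # The mllib tests assume an active session, normally created by the
  # suite's setup code.
  sparkR.session(enableHiveSupport = FALSE)
  test_file("R/pkg/inst/tests/testthat/test_mllib.R")
  sparkR.session.stop()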
mllib/src/main/scala/org/apache/spark/ml/r/GBTClassifierWrapper.scala

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
+import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class GBTClassifierWrapper private (
+    val pipeline: PipelineModel,
+    val formula: String,
+    val features: Array[String]) extends MLWritable {
+
+  import GBTClassifierWrapper._
+
+  private val gbtcModel: GBTClassificationModel =
+    pipeline.stages(1).asInstanceOf[GBTClassificationModel]
+
+  lazy val numFeatures: Int = gbtcModel.numFeatures
+  lazy val featureImportances: Vector = gbtcModel.featureImportances
+  lazy val numTrees: Int = gbtcModel.getNumTrees
+  lazy val treeWeights: Array[Double] = gbtcModel.treeWeights
+
+  def summary: String = gbtcModel.toDebugString
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset)
+      .drop(PREDICTED_LABEL_INDEX_COL)
+      .drop(gbtcModel.getFeaturesCol)
+  }
+
+  override def write: MLWriter = new
+    GBTClassifierWrapper.GBTClassifierWrapperWriter(this)
+}
+
+private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] {
+
+  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+  val PREDICTED_LABEL_COL = "prediction"
+
+  def fit(  // scalastyle:ignore
+      data: DataFrame,
+      formula: String,
+      maxDepth: Int,
+      maxBins: Int,
+      maxIter: Int,
+      stepSize: Double,
+      minInstancesPerNode: Int,
+      minInfoGain: Double,
+      checkpointInterval: Int,
+      lossType: String,
+      seed: String,
+      subsamplingRate: Double,
+      maxMemoryInMB: Int,
+      cacheNodeIds: Boolean): GBTClassifierWrapper = {
+
+    val rFormula = new RFormula()
+      .setFormula(formula)
+      .setForceIndexLabel(true)
+    RWrapperUtils.checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    // get feature names from output schema
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+
+    // get label names from output schema
+    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
+      .asInstanceOf[NominalAttribute]
+    val labels = labelAttr.values.get
+
+    // assemble and fit the pipeline
+    val rfc = new GBTClassifier()
+      .setMaxDepth(maxDepth)
+      .setMaxBins(maxBins)
+      .setMaxIter(maxIter)
+      .setStepSize(stepSize)
+      .setMinInstancesPerNode(minInstancesPerNode)
+      .setMinInfoGain(minInfoGain)
+      .setCheckpointInterval(checkpointInterval)
+      .setLossType(lossType)
+      .setSubsamplingRate(subsamplingRate)
+      .setMaxMemoryInMB(maxMemoryInMB)
+      .setCacheNodeIds(cacheNodeIds)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+    if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
+
+    val idxToStr = new IndexToString()
+      .setInputCol(PREDICTED_LABEL_INDEX_COL)
+      .setOutputCol(PREDICTED_LABEL_COL)
+      .setLabels(labels)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, rfc, idxToStr))
+      .fit(data)
+
+    new GBTClassifierWrapper(pipeline, formula, features)
+  }
+
+  override def read: MLReader[GBTClassifierWrapper] = new GBTClassifierWrapperReader
+
+  override def load(path: String): GBTClassifierWrapper = super.load(path)
+
+  class GBTClassifierWrapperWriter(instance: GBTClassifierWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("formula" -> instance.formula) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class GBTClassifierWrapperReader extends MLReader[GBTClassifierWrapper] {
+
+    override def load(path: String): GBTClassifierWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val formula = (rMetadata \ "formula").extract[String]
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      new GBTClassifierWrapper(pipeline, formula, features)
+    }
+  }
+}
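
On the R side, this writer/reader pair is what backs write.ml and read.ml for the classification model. A short round-trip sketch mirroring the tests above; the rMetadata and pipeline subdirectories follow saveImpl:

  df <- suppressWarnings(createDataFrame(iris[iris$Species != "virginica", ]))
  model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification")

  # saveImpl writes rMetadata (JSON with class, formula, features)
  # plus the fitted pipeline under the given path.
  modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp")
  write.ml(model, modelPath)

  # read.ml dispatches on the class name stored in rMetadata and
  # rebuilds an equivalent wrapper via GBTClassifierWrapperReader.load.
  model2 <- read.ml(modelPath)
  summary(model2)
  unlink(modelPath)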
