diff --git a/.rat-excludes b/.rat-excludes
index c0f81b57fe09..aa008e6e920f 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -28,6 +28,7 @@ spark-env.sh
spark-env.cmd
spark-env.sh.template
log4j-defaults.properties
+log4j-defaults-repl.properties
bootstrap-tooltip.js
jquery-1.11.1.min.js
d3.min.js
@@ -80,5 +81,8 @@ local-1425081759269/*
local-1426533911241/*
local-1426633911242/*
local-1430917381534/*
+local-1430917381535_1
+local-1430917381535_2
DESCRIPTION
NAMESPACE
+test_support/*
diff --git a/LICENSE b/LICENSE
index 9d1b00beff74..d0cd0dcb4bdb 100644
--- a/LICENSE
+++ b/LICENSE
@@ -853,6 +853,52 @@ and
Vis.js may be distributed under either license.
+========================================================================
+For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js):
+========================================================================
+Copyright (c) 2013 Chris Pettitt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
+========================================================================
+For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js):
+========================================================================
+Copyright (c) 2012-2013 Chris Pettitt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
========================================================================
BSD-style licenses
========================================================================
diff --git a/R/create-docs.sh b/R/create-docs.sh
index 4194172a2e11..6a4687b06ecb 100755
--- a/R/create-docs.sh
+++ b/R/create-docs.sh
@@ -23,14 +23,14 @@
# After running this script the html docs can be found in
# $SPARK_HOME/R/pkg/html
+set -o pipefail
+set -e
+
# Figure out where the script is
export FWDIR="$(cd "`dirname "$0"`"; pwd)"
pushd $FWDIR
-# Generate Rd file
-Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))'
-
-# Install the package
+# Install the package (this will also generate the Rd files)
./install-dev.sh
# Now create HTML files
diff --git a/R/install-dev.sh b/R/install-dev.sh
index 55ed6f4be1a4..1edd551f8d24 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -26,11 +26,20 @@
# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory
# to load the SparkR package on the worker nodes.
+set -o pipefail
+set -e
FWDIR="$(cd `dirname $0`; pwd)"
LIB_DIR="$FWDIR/lib"
mkdir -p $LIB_DIR
-# Install R
+pushd $FWDIR
+
+# Generate Rd files if devtools is installed
+Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }'
+
+# Install SparkR to $LIB_DIR
R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/
+
+popd
diff --git a/R/log4j.properties b/R/log4j.properties
index 701adb2a3da1..cce8d9152d32 100644
--- a/R/log4j.properties
+++ b/R/log4j.properties
@@ -19,7 +19,7 @@
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=true
-log4j.appender.file.file=R-unit-tests.log
+log4j.appender.file.file=R/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 411126a37795..f9447f6c3288 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -19,9 +19,11 @@ exportMethods("arrange",
"count",
"describe",
"distinct",
+ "dropna",
"dtypes",
"except",
"explain",
+ "fillna",
"filter",
"first",
"group_by",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index ed8093c80d36..0af5cb8881e3 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1314,9 +1314,8 @@ setMethod("except",
#' write.df(df, "myfile", "parquet", "overwrite")
#' }
setMethod("write.df",
- signature(df = "DataFrame", path = 'character', source = 'character',
- mode = 'character'),
- function(df, path = NULL, source = NULL, mode = "append", ...){
+ signature(df = "DataFrame", path = 'character'),
+ function(df, path, source = NULL, mode = "append", ...){
if (is.null(source)) {
sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
@@ -1338,9 +1337,8 @@ setMethod("write.df",
#' @aliases saveDF
#' @export
setMethod("saveDF",
- signature(df = "DataFrame", path = 'character', source = 'character',
- mode = 'character'),
- function(df, path = NULL, source = NULL, mode = "append", ...){
+ signature(df = "DataFrame", path = 'character'),
+ function(df, path, source = NULL, mode = "append", ...){
write.df(df, path, source, mode, ...)
})
@@ -1431,3 +1429,128 @@ setMethod("describe",
sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
dataFrame(sdf)
})
+
+#' dropna
+#'
+#' Returns a new DataFrame omitting rows with null values.
+#'
+#' @param x A SparkSQL DataFrame.
+#' @param how "any" or "all".
+#' if "any", drop a row if it contains any nulls.
+#' if "all", drop a row only if all its values are null.
+#' if minNonNulls is specified, how is ignored.
+#' @param minNonNulls If specified, drop rows that have fewer than
+#' minNonNulls non-null values.
+#' This overrides the how parameter.
+#' @param cols Optional list of column names to consider.
+#' @return A DataFrame
+#'
+#' @rdname nafunctions
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' dropna(df)
+#' }
+setMethod("dropna",
+ signature(x = "DataFrame"),
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ how <- match.arg(how)
+ if (is.null(cols)) {
+ cols <- columns(x)
+ }
+ if (is.null(minNonNulls)) {
+ minNonNulls <- if (how == "any") { length(cols) } else { 1 }
+ }
+
+ naFunctions <- callJMethod(x@sdf, "na")
+ sdf <- callJMethod(naFunctions, "drop",
+ as.integer(minNonNulls), listToSeq(as.list(cols)))
+ dataFrame(sdf)
+ })
+
+#' @aliases dropna
+#' @export
+setMethod("na.omit",
+ signature(x = "DataFrame"),
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ dropna(x, how, minNonNulls, cols)
+ })
+
+#' fillna
+#'
+#' Replace null values.
+#'
+#' @param x A SparkSQL DataFrame.
+#' @param value Value to replace null values with.
+#' Should be an integer, numeric, character or named list.
+#' If the value is a named list, then cols is ignored and
+#' value must be a mapping from column name (character) to
+#' replacement value. The replacement value must be an
+#' integer, numeric or character.
+#' @param cols optional list of column names to consider.
+#' Columns specified in cols that do not have matching data
+#' type are ignored. For example, if value is a character, and
+#' subset contains a non-character column, then the non-character
+#' column is simply ignored.
+#' @return A DataFrame
+#'
+#' @rdname nafunctions
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' fillna(df, 1)
+#' fillna(df, list("age" = 20, "name" = "unknown"))
+#' }
+setMethod("fillna",
+ signature(x = "DataFrame"),
+ function(x, value, cols = NULL) {
+ if (!(class(value) %in% c("integer", "numeric", "character", "list"))) {
+ stop("value should be an integer, numeric, charactor or named list.")
+ }
+
+ if (class(value) == "list") {
+ # Check column names in the named list
+ colNames <- names(value)
+ if (length(colNames) == 0 || !all(colNames != "")) {
+ stop("value should be an a named list with each name being a column name.")
+ }
+
+ # Convert the named list to an environment to be passed to the JVM
+ valueMap <- new.env()
+ for (col in colNames) {
+ # Check each item in the named list is of valid type
+ v <- value[[col]]
+ if (!(class(v) %in% c("integer", "numeric", "character"))) {
+ stop("Each item in value should be an integer, numeric or charactor.")
+ }
+ valueMap[[col]] <- v
+ }
+
+ # When value is a named list, caller is expected not to pass in cols
+ if (!is.null(cols)) {
+ warning("When value is a named list, cols is ignored!")
+ cols <- NULL
+ }
+
+ value <- valueMap
+ } else if (is.integer(value)) {
+ # Cast an integer to a numeric
+ value <- as.numeric(value)
+ }
+
+ naFunctions <- callJMethod(x@sdf, "na")
+ sdf <- if (length(cols) == 0) {
+ callJMethod(naFunctions, "fill", value)
+ } else {
+ callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols)))
+ }
+ dataFrame(sdf)
+ })
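A minimal usage sketch for the new SparkR NA functions added above (not part of the patch itself; it assumes an initialized sqlContext and a DataFrame `df` containing null values, and all object names are placeholders):

    # Drop rows: `how` chooses between "any"/"all"; a given minNonNulls
    # overrides `how` and keeps rows with at least that many non-null values.
    noNulls  <- dropna(df, how = "any", cols = c("age", "height"))
    atLeast2 <- dropna(df, minNonNulls = 2, cols = c("age", "height"))

    # Fill nulls with a single value, or per column via a named list.
    filledAll <- fillna(df, 0)
    filledMap <- fillna(df, list("age" = 20, "name" = "unknown"))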
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 36cc61287587..22a4b5bf86eb 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -452,20 +452,31 @@ dropTempTable <- function(sqlContext, tableName) {
#' df <- read.df(sqlContext, "path/to/file.json", source = "json")
#' }
-read.df <- function(sqlContext, path = NULL, source = NULL, ...) {
+read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) {
options <- varargsToEnv(...)
if (!is.null(path)) {
options[['path']] <- path
}
- sdf <- callJMethod(sqlContext, "load", source, options)
+ if (is.null(source)) {
+ sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
+ source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ }
+ if (!is.null(schema)) {
+ stopifnot(class(schema) == "structType")
+ sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source,
+ schema$jobj, options)
+ } else {
+ sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options)
+ }
dataFrame(sdf)
}
#' @aliases loadDF
#' @export
-loadDF <- function(sqlContext, path = NULL, source = NULL, ...) {
- read.df(sqlContext, path, source, ...)
+loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) {
+ read.df(sqlContext, path, source, schema, ...)
}
#' Create an external table
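A sketch of reading a data source with the new optional `schema` argument to read.df()/loadDF() (assumes an initialized sqlContext; the path is a placeholder):

    # Supply an explicit schema instead of relying on schema inference.
    schema <- structType(structField("name", type = "string"),
                         structField("age", type = "double"))
    people <- read.df(sqlContext, "path/to/file.json", source = "json",
                      schema = schema)
    printSchema(people)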
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index a23d3b217b2f..12e09176c9f9 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -396,6 +396,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") })
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
+#' @rdname nafunctions
+#' @export
+setGeneric("dropna",
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ standardGeneric("dropna")
+ })
+
+#' @rdname nafunctions
+#' @export
+setGeneric("na.omit",
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ standardGeneric("na.omit")
+ })
+
#' @rdname schema
#' @export
setGeneric("dtypes", function(x) { standardGeneric("dtypes") })
@@ -408,6 +422,10 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") })
#' @export
setGeneric("except", function(x, y) { standardGeneric("except") })
+#' @rdname nafunctions
+#' @export
+setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") })
+
#' @rdname filter
#' @export
setGeneric("filter", function(x, condition) { standardGeneric("filter") })
@@ -482,11 +500,11 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) {
#' @rdname write.df
#' @export
-setGeneric("write.df", function(df, path, source, mode, ...) { standardGeneric("write.df") })
+setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
#' @rdname write.df
#' @export
-setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") })
+setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") })
#' @rdname schema
#' @export
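With the generics relaxed above, only `df` and `path` are required when saving; a sketch of the intended call patterns (an existing DataFrame `df` is assumed and the paths are placeholders):

    # source defaults to spark.sql.sources.default (parquet by default)
    # and mode defaults to "append".
    write.df(df, path = "people.parquet")

    # An explicit source and save mode can still be given.
    saveDF(df, path = "people.json", source = "json", mode = "overwrite")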
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index c53d0a961016..3169d7968f8f 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -37,6 +37,14 @@ writeObject <- function(con, object, writeType = TRUE) {
# passing in vectors as arrays and instead require arrays to be passed
# as lists.
type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt")
+ # Checking types is needed here, since 'is.na' only handles atomic vectors,
+ # lists and pairlists
+ if (type %in% c("integer", "character", "logical", "double", "numeric")) {
+ if (is.na(object)) {
+ object <- NULL
+ type <- "NULL"
+ }
+ }
if (writeType) {
writeType(con, type)
}
@@ -160,6 +168,14 @@ writeList <- function(con, arr) {
}
}
+# Used to pass arrays where the elements can be of different types
+writeGenericList <- function(con, list) {
+ writeInt(con, length(list))
+ for (elem in list) {
+ writeObject(con, elem)
+ }
+}
+
# Used to pass in hash maps required on Java side.
writeEnv <- function(con, env) {
len <- length(env)
@@ -168,7 +184,7 @@ writeEnv <- function(con, env) {
if (len > 0) {
writeList(con, as.list(ls(env)))
vals <- lapply(ls(env), function(x) { env[[x]] })
- writeList(con, as.list(vals))
+ writeGenericList(con, as.list(vals))
}
}
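The class check above is needed because is.na() is only defined for atomic vectors, lists and pairlists; a small base-R sketch of the behaviour the guard relies on (illustrative only):

    is.na(NA_integer_)   # TRUE  -> writeObject() sends it with type "NULL"
    is.na("spark")       # FALSE -> serialized normally as a character value
    # For unsupported objects (e.g. an environment), is.na() only warns rather
    # than giving a useful answer, hence the explicit class check before calling it.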
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 68387f0f5365..5ced7c688f98 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -225,14 +225,21 @@ sparkR.init <- function(
#' sqlContext <- sparkRSQL.init(sc)
#'}
-sparkRSQL.init <- function(jsc) {
+sparkRSQL.init <- function(jsc = NULL) {
if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
return(get(".sparkRSQLsc", envir = .sparkREnv))
}
+ # If jsc is NULL, create a Spark Context
+ sc <- if (is.null(jsc)) {
+ sparkR.init()
+ } else {
+ jsc
+ }
+
sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
- "createSQLContext",
- jsc)
+ "createSQLContext",
+ sc)
assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv)
sqlContext
}
@@ -249,12 +256,19 @@ sparkRSQL.init <- function(jsc) {
#' sqlContext <- sparkRHive.init(sc)
#'}
-sparkRHive.init <- function(jsc) {
+sparkRHive.init <- function(jsc = NULL) {
if (exists(".sparkRHivesc", envir = .sparkREnv)) {
return(get(".sparkRHivesc", envir = .sparkREnv))
}
- ssc <- callJMethod(jsc, "sc")
+ # If jsc is NULL, create a Spark Context
+ sc <- if (is.null(jsc)) {
+ sparkR.init()
+ } else {
+ jsc
+ }
+
+ ssc <- callJMethod(sc, "sc")
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.HiveContext", ssc)
}, error = function(err) {
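With jsc now defaulting to NULL, a SQL or Hive context can be obtained without constructing a SparkContext by hand first; a minimal sketch (names and data are illustrative):

    library(SparkR)
    # Falls back to sparkR.init() internally when no SparkContext is supplied.
    sqlContext <- sparkRSQL.init()
    df <- createDataFrame(sqlContext, faithful)
    head(df)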
diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R
index ca94f1d4e7fd..773b6ecf582d 100644
--- a/R/pkg/inst/profile/shell.R
+++ b/R/pkg/inst/profile/shell.R
@@ -24,7 +24,7 @@
old <- getOption("defaultPackages")
options(defaultPackages = c(old, "SparkR"))
- sc <- SparkR::sparkR.init(Sys.getenv("MASTER", unset = ""))
+ sc <- SparkR::sparkR.init()
assign("sc", sc, envir=.GlobalEnv)
sqlContext <- SparkR::sparkRSQL.init(sc)
assign("sqlContext", sqlContext, envir=.GlobalEnv)
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 1857e636e857..8946348ef801 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -32,6 +32,15 @@ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet")
writeLines(mockLines, jsonPath)
+# For testing NA functions such as dropna(), fillna(), ...
+mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}",
+ "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}",
+ "{\"name\":\"David\",\"age\":60,\"height\":null}",
+ "{\"name\":\"Amy\",\"age\":null,\"height\":null}",
+ "{\"name\":null,\"age\":null,\"height\":null}")
+jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp")
+writeLines(mockLinesNa, jsonPathNa)
+
test_that("infer types", {
expect_equal(infer_type(1L), "integer")
expect_equal(infer_type(1.0), "double")
@@ -92,6 +101,43 @@ test_that("create DataFrame from RDD", {
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
})
+test_that("convert NAs to null type in DataFrames", {
+ rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L)))
+ df <- createDataFrame(sqlContext, rdd, list("a", "b"))
+ expect_true(is.na(collect(df)[2, "a"]))
+ expect_equal(collect(df)[2, "b"], 4L)
+
+ l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L))
+ df <- createDataFrame(sqlContext, l)
+ expect_equal(collect(df)[2, "x"], 1L)
+ expect_true(is.na(collect(df)[2, "y"]))
+
+ rdd <- parallelize(sc, list(list(1, 2), list(NA, 4)))
+ df <- createDataFrame(sqlContext, rdd, list("a", "b"))
+ expect_true(is.na(collect(df)[2, "a"]))
+ expect_equal(collect(df)[2, "b"], 4)
+
+ l <- data.frame(x = 1, y = c(1, NA_real_, 3))
+ df <- createDataFrame(sqlContext, l)
+ expect_equal(collect(df)[2, "x"], 1)
+ expect_true(is.na(collect(df)[2, "y"]))
+
+ l <- list("a", "b", NA, "d")
+ df <- createDataFrame(sqlContext, l)
+ expect_true(is.na(collect(df)[3, "_1"]))
+ expect_equal(collect(df)[4, "_1"], "d")
+
+ l <- list("a", "b", NA_character_, "d")
+ df <- createDataFrame(sqlContext, l)
+ expect_true(is.na(collect(df)[3, "_1"]))
+ expect_equal(collect(df)[4, "_1"], "d")
+
+ l <- list(TRUE, FALSE, NA, TRUE)
+ df <- createDataFrame(sqlContext, l)
+ expect_true(is.na(collect(df)[3, "_1"]))
+ expect_equal(collect(df)[4, "_1"], TRUE)
+})
+
test_that("toDF", {
rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
df <- toDF(rdd, list("a", "b"))
@@ -495,6 +541,19 @@ test_that("read.df() from json file", {
df <- read.df(sqlContext, jsonPath, "json")
expect_true(inherits(df, "DataFrame"))
expect_true(count(df) == 3)
+
+ # Check if we can apply a user defined schema
+ schema <- structType(structField("name", type = "string"),
+ structField("age", type = "double"))
+
+ df1 <- read.df(sqlContext, jsonPath, "json", schema)
+ expect_true(inherits(df1, "DataFrame"))
+ expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double")))
+
+ # Run the same with loadDF
+ df2 <- loadDF(sqlContext, jsonPath, "json", schema)
+ expect_true(inherits(df2, "DataFrame"))
+ expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double")))
})
test_that("write.df() as parquet file", {
@@ -765,5 +824,105 @@ test_that("describe() on a DataFrame", {
expect_equal(collect(stats)[5, "age"], "30")
})
+test_that("dropna() on a DataFrame", {
+ df <- jsonFile(sqlContext, jsonPathNa)
+ rows <- collect(df)
+
+ # drop with columns
+
+ expected <- rows[!is.na(rows$name),]
+ actual <- collect(dropna(df, cols = "name"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age),]
+ actual <- collect(dropna(df, cols = "age"))
+ row.names(expected) <- row.names(actual)
+ # identical on two dataframes does not work here. Don't know why.
+ # use identical on all columns as a workaround.
+ expect_true(identical(expected$age, actual$age))
+ expect_true(identical(expected$height, actual$height))
+ expect_true(identical(expected$name, actual$name))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height),]
+ actual <- collect(dropna(df, cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
+ actual <- collect(dropna(df))
+ expect_true(identical(expected, actual))
+
+ # drop with how
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
+ actual <- collect(dropna(df))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),]
+ actual <- collect(dropna(df, "all"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
+ actual <- collect(dropna(df, "any"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height),]
+ actual <- collect(dropna(df, "any", cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) | !is.na(rows$height),]
+ actual <- collect(dropna(df, "all", cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ # drop with threshold
+
+ expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,]
+ actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[as.integer(!is.na(rows$age)) +
+ as.integer(!is.na(rows$height)) +
+ as.integer(!is.na(rows$name)) >= 3,]
+ actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height")))
+ expect_true(identical(expected, actual))
+})
+
+test_that("fillna() on a DataFrame", {
+ df <- jsonFile(sqlContext, jsonPathNa)
+ rows <- collect(df)
+
+ # fill with value
+
+ expected <- rows
+ expected$age[is.na(expected$age)] <- 50
+ expected$height[is.na(expected$height)] <- 50.6
+ actual <- collect(fillna(df, 50.6))
+ expect_true(identical(expected, actual))
+
+ expected <- rows
+ expected$name[is.na(expected$name)] <- "unknown"
+ actual <- collect(fillna(df, "unknown"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows
+ expected$age[is.na(expected$age)] <- 50
+ actual <- collect(fillna(df, 50.6, "age"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows
+ expected$name[is.na(expected$name)] <- "unknown"
+ actual <- collect(fillna(df, "unknown", c("age", "name")))
+ expect_true(identical(expected, actual))
+
+ # fill with named list
+
+ expected <- rows
+ expected$age[is.na(expected$age)] <- 50
+ expected$height[is.na(expected$height)] <- 50.6
+ expected$name[is.na(expected$name)] <- "unknown"
+ actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown")))
+ expect_true(identical(expected, actual))
+})
+
unlink(parquetPath)
unlink(jsonPath)
+unlink(jsonPathNa)
diff --git a/README.md b/README.md
index 9c09d40e2bda..380422ca00db 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,8 @@
Spark is a fast and general cluster computing system for Big Data. It provides
high-level APIs in Scala, Java, and Python, and an optimized engine that
supports general computation graphs for data analysis. It also supports a
-rich set of higher-level tools including Spark SQL for SQL and structured
-data processing, MLlib for machine learning, GraphX for graph processing,
+rich set of higher-level tools including Spark SQL for SQL and DataFrames,
+MLlib for machine learning, GraphX for graph processing,
and Spark Streaming for stream processing.
@@ -22,7 +22,7 @@ This README file only contains basic setup instructions.
Spark is built using [Apache Maven](http://maven.apache.org/).
To build Spark and its example programs, run:
- mvn -DskipTests clean package
+ build/mvn -DskipTests clean package
(You do not need to do this if you downloaded a pre-built package.)
More detailed documentation is available from the project site, at
@@ -43,7 +43,7 @@ Try the following command, which should return 1000:
Alternatively, if you prefer Python, you can use the Python shell:
./bin/pyspark
-
+
And run the following command, which should also return 1000:
>>> sc.parallelize(range(1000)).count()
@@ -58,9 +58,9 @@ To run one of them, use `./bin/run-example [params]`. For example:
will run the Pi example locally.
You can set the MASTER environment variable when running examples to submit
-examples to a cluster. This can be a mesos:// or spark:// URL,
-"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run
-locally with one thread, or "local[N]" to run locally with N threads. You
+examples to a cluster. This can be a mesos:// or spark:// URL,
+"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run
+locally with one thread, or "local[N]" to run locally with N threads. You
can also use an abbreviated class name if the class is in the `examples`
package. For instance:
@@ -75,7 +75,7 @@ can be run using:
./dev/run-tests
-Please see the guidance on how to
+Please see the guidance on how to
[run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools).
## A Note About Hadoop Versions
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 626c8577e31f..e9c6d26ccddc 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 1f3dec91314f..ed5c37e595a9 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
@@ -40,6 +40,13 @@
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.scalacheck</groupId>
      <artifactId>scalacheck_${scala.binary.version}</artifactId>
diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
index ccb262a4ee02..fb10d734ac74 100644
--- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.bagel
-import org.scalatest.{BeforeAndAfter, FunSuite, Assertions}
+import org.scalatest.{BeforeAndAfter, Assertions}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
@@ -27,7 +27,7 @@ import org.apache.spark.storage.StorageLevel
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
-class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts {
+class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts {
var sc: SparkContext = _
diff --git a/bin/pyspark b/bin/pyspark
index 8acad6113797..f9dbddfa5356 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -17,24 +17,10 @@
# limitations under the License.
#
-# Figure out where Spark is installed
export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
source "$SPARK_HOME"/bin/load-spark-env.sh
-
-function usage() {
- if [ -n "$1" ]; then
- echo $1
- fi
- echo "Usage: ./bin/pyspark [options]" 1>&2
- "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
- exit $2
-}
-export -f usage
-
-if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
- usage
-fi
+export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]"
# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython`
# executable, while the worker would still be launched using PYSPARK_PYTHON.
@@ -90,11 +76,7 @@ if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
export PYTHONHASHSEED=0
- if [[ -n "$PYSPARK_DOC_TEST" ]]; then
- exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1
- else
- exec "$PYSPARK_DRIVER_PYTHON" $1
- fi
+ exec "$PYSPARK_DRIVER_PYTHON" -m $1
exit
fi
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index 09b4149c2a43..45e9e3def512 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -21,6 +21,7 @@ rem Figure out where the Spark framework is installed
set SPARK_HOME=%~dp0..
call %SPARK_HOME%\bin\load-spark-env.cmd
+set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options]
rem Figure out which Python to use.
if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
diff --git a/bin/spark-class b/bin/spark-class
index c49d97ce5cf2..2b59e5df5736 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -16,18 +16,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-set -e
# Figure out where Spark is installed
export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
. "$SPARK_HOME"/bin/load-spark-env.sh
-if [ -z "$1" ]; then
- echo "Usage: spark-class []" 1>&2
- exit 1
-fi
-
# Find the java binary
if [ -n "${JAVA_HOME}" ]; then
RUNNER="${JAVA_HOME}/bin/java"
@@ -64,24 +58,6 @@ fi
SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}"
-# Verify that versions of java used to build the jars and run Spark are compatible
-if [ -n "$JAVA_HOME" ]; then
- JAR_CMD="$JAVA_HOME/bin/jar"
-else
- JAR_CMD="jar"
-fi
-
-if [ $(command -v "$JAR_CMD") ] ; then
- jar_error_check=$("$JAR_CMD" -tf "$SPARK_ASSEMBLY_JAR" nonexistent/class/path 2>&1)
- if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then
- echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2
- echo "This is likely because Spark was compiled with Java 7 and run " 1>&2
- echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2
- echo "or build Spark with Java 6." 1>&2
- exit 1
- fi
-fi
-
LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR"
# Add the launcher build dir to the classpath if requested.
@@ -98,9 +74,4 @@ CMD=()
while IFS= read -d '' -r ARG; do
CMD+=("$ARG")
done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@")
-
-if [ "${CMD[0]}" = "usage" ]; then
- "${CMD[@]}"
-else
- exec "${CMD[@]}"
-fi
+exec "${CMD[@]}"
diff --git a/bin/spark-shell b/bin/spark-shell
index b3761b5e1375..a6dc863d83fc 100755
--- a/bin/spark-shell
+++ b/bin/spark-shell
@@ -29,20 +29,7 @@ esac
set -o posix
export FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
-
-usage() {
- if [ -n "$1" ]; then
- echo "$1"
- fi
- echo "Usage: ./bin/spark-shell [options]"
- "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
- exit "$2"
-}
-export -f usage
-
-if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
- usage "" 0
-fi
+export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]"
# SPARK-4161: scala does not assume use of the java classpath,
# so we need to add the "-Dscala.usejavacp=true" flag manually. We
diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd
index 00fd30fa38d3..251309d67f86 100644
--- a/bin/spark-shell2.cmd
+++ b/bin/spark-shell2.cmd
@@ -18,12 +18,7 @@ rem limitations under the License.
rem
set SPARK_HOME=%~dp0..
-
-echo "%*" | findstr " \<--help\> \<-h\>" >nul
-if %ERRORLEVEL% equ 0 (
- call :usage
- exit /b 0
-)
+set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options]
rem SPARK-4161: scala does not assume use of the java classpath,
rem so we need to add the "-Dscala.usejavacp=true" flag manually. We
@@ -37,16 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" (
set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true"
:run_shell
-call %SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %*
-set SPARK_ERROR_LEVEL=%ERRORLEVEL%
-if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" (
- call :usage
- exit /b 1
-)
-exit /b %SPARK_ERROR_LEVEL%
-
-:usage
-echo %SPARK_LAUNCHER_USAGE_ERROR%
-echo "Usage: .\bin\spark-shell.cmd [options]" >&2
-call %SPARK_HOME%\bin\spark-submit2.cmd --help 2>&1 | findstr /V "Usage" 1>&2
-goto :eof
+%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %*
diff --git a/bin/spark-sql b/bin/spark-sql
index ca1729f4cfcb..4ea7bc6e39c0 100755
--- a/bin/spark-sql
+++ b/bin/spark-sql
@@ -17,41 +17,6 @@
# limitations under the License.
#
-#
-# Shell script for starting the Spark SQL CLI
-
-# Enter posix mode for bash
-set -o posix
-
-# NOTE: This exact class name is matched downstream by SparkSubmit.
-# Any changes need to be reflected there.
-export CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
-
-# Figure out where Spark is installed
export FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
-
-function usage {
- if [ -n "$1" ]; then
- echo "$1"
- fi
- echo "Usage: ./bin/spark-sql [options] [cli option]"
- pattern="usage"
- pattern+="\|Spark assembly has been built with Hive"
- pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set"
- pattern+="\|Spark Command: "
- pattern+="\|--help"
- pattern+="\|======="
-
- "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
- echo
- echo "CLI options:"
- "$FWDIR"/bin/spark-class "$CLASS" --help 2>&1 | grep -v "$pattern" 1>&2
- exit "$2"
-}
-export -f usage
-
-if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
- usage "" 0
-fi
-
-exec "$FWDIR"/bin/spark-submit --class "$CLASS" "$@"
+export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]"
+exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@"
diff --git a/bin/spark-submit b/bin/spark-submit
index 0e0afe71a0f0..255378b0f077 100755
--- a/bin/spark-submit
+++ b/bin/spark-submit
@@ -22,16 +22,4 @@ SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
# disable randomized hash for string in Python 3.3+
export PYTHONHASHSEED=0
-# Only define a usage function if an upstream script hasn't done so.
-if ! type -t usage >/dev/null 2>&1; then
- usage() {
- if [ -n "$1" ]; then
- echo "$1"
- fi
- "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit --help
- exit "$2"
- }
- export -f usage
-fi
-
exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@"
diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd
index d3fc4a5cc3f6..651376e52692 100644
--- a/bin/spark-submit2.cmd
+++ b/bin/spark-submit2.cmd
@@ -24,15 +24,4 @@ rem disable randomized hash for string in Python 3.3+
set PYTHONHASHSEED=0
set CLASS=org.apache.spark.deploy.SparkSubmit
-call %~dp0spark-class2.cmd %CLASS% %*
-set SPARK_ERROR_LEVEL=%ERRORLEVEL%
-if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" (
- call :usage
- exit /b 1
-)
-exit /b %SPARK_ERROR_LEVEL%
-
-:usage
-echo %SPARK_LAUNCHER_USAGE_ERROR%
-call %SPARK_HOME%\bin\spark-class2.cmd %CLASS% --help
-goto :eof
+%~dp0spark-class2.cmd %CLASS% %*
diff --git a/bin/sparkR b/bin/sparkR
index 8c918e2b09ae..464c29f36942 100755
--- a/bin/sparkR
+++ b/bin/sparkR
@@ -17,23 +17,7 @@
# limitations under the License.
#
-# Figure out where Spark is installed
export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
-
source "$SPARK_HOME"/bin/load-spark-env.sh
-
-function usage() {
- if [ -n "$1" ]; then
- echo $1
- fi
- echo "Usage: ./bin/sparkR [options]" 1>&2
- "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
- exit $2
-}
-export -f usage
-
-if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
- usage
-fi
-
+export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]"
exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@"
diff --git a/build/mvn b/build/mvn
index 3561110a4c01..e8364181e823 100755
--- a/build/mvn
+++ b/build/mvn
@@ -69,11 +69,14 @@ install_app() {
# Install maven under the build/ folder
install_mvn() {
+ local MVN_VERSION="3.3.3"
+
install_app \
- "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \
- "apache-maven-3.2.5-bin.tar.gz" \
- "apache-maven-3.2.5/bin/mvn"
- MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn"
+ "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \
+ "apache-maven-${MVN_VERSION}-bin.tar.gz" \
+ "apache-maven-${MVN_VERSION}/bin/mvn"
+
+ MVN_BIN="${_DIR}/apache-maven-${MVN_VERSION}/bin/mvn"
}
# Install zinc under the build/ folder
@@ -105,28 +108,16 @@ install_scala() {
SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar"
}
-# Determines if a given application is already installed. If not, will attempt
-# to install
-## Arg1 - application name
-## Arg2 - Alternate path to local install under build/ dir
-check_and_install_app() {
- # create the local environment variable in uppercase
- local app_bin="`echo $1 | awk '{print toupper(\$0)}'`_BIN"
- # some black magic to set the generated app variable (i.e. MVN_BIN) into the
- # environment
- eval "${app_bin}=`which $1 2>/dev/null`"
-
- if [ -z "`which $1 2>/dev/null`" ]; then
- install_$1
- fi
-}
-
# Setup healthy defaults for the Zinc port if none were provided from
# the environment
ZINC_PORT=${ZINC_PORT:-"3030"}
-# Check and install all applications necessary to build Spark
-check_and_install_app "mvn"
+# Install Maven if necessary
+MVN_BIN="$(command -v mvn)"
+
+if [ ! "$MVN_BIN" ]; then
+ install_mvn
+fi
# Install the proper version of Scala and Zinc for the build
install_zinc
diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template
index 7de0011a48ca..7f17bc7eea4f 100644
--- a/conf/metrics.properties.template
+++ b/conf/metrics.properties.template
@@ -4,7 +4,7 @@
# divided into instances which correspond to internal components.
# Each instance can be configured to report its metrics to one or more sinks.
# Accepted values for [instance] are "master", "worker", "executor", "driver",
-# and "applications". A wild card "*" can be used as an instance name, in
+# and "applications". A wildcard "*" can be used as an instance name, in
# which case all instances will inherit the supplied property.
#
# Within an instance, a "source" specifies a particular set of grouped metrics.
@@ -32,7 +32,7 @@
# name (see examples below).
# 2. Some sinks involve a polling period. The minimum allowed polling period
# is 1 second.
-# 3. Wild card properties can be overridden by more specific properties.
+# 3. Wildcard properties can be overridden by more specific properties.
# For example, master.sink.console.period takes precedence over
# *.sink.console.period.
# 4. A metrics specific configuration
@@ -47,6 +47,13 @@
# instance master and applications. MetricsServlet may not be configured by self.
#
+## List of available common sources and their properties.
+
+# org.apache.spark.metrics.source.JvmSource
+# Note: Currently, JvmSource is the only available common source.
+#       To add it to an instance, set its "class" option to the fully
+#       qualified class name of the source (see the examples below).
+
## List of available sinks and their properties.
# org.apache.spark.metrics.sink.ConsoleSink
diff --git a/core/pom.xml b/core/pom.xml
index e58efe495e36..40a64beccdc2 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
@@ -338,6 +338,12 @@
      <groupId>org.seleniumhq.selenium</groupId>
      <artifactId>selenium-java</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
      <scope>test</scope>
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
new file mode 100644
index 000000000000..d3d6280284be
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import com.google.common.io.Closeables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.Partitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
+import org.apache.spark.util.Utils;
+
+/**
+ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
+ * writes incoming records to separate files, one file per reduce partition, then concatenates these
+ * per-partition files to form a single output file, regions of which are served to reducers.
+ * Records are not buffered in memory. This is essentially identical to
+ * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format
+ * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}.
+ *
+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it
+ * simultaneously opens separate serializers and file streams for all partitions. As a result,
+ * {@link SortShuffleManager} only selects this write path when
+ *
+ * <ul>
+ *   <li>no Ordering is specified,</li>
+ *   <li>no Aggregator is specified, and</li>
+ *   <li>the number of partitions is less than
+ *     <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
+ * </ul>
+ *
+ * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details.
+ *
+ * There have been proposals to completely remove this code path; see SPARK-6026 for details.
+ */
+final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+
+ private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
+
+ private final int fileBufferSize;
+ private final boolean transferToEnabled;
+ private final int numPartitions;
+ private final BlockManager blockManager;
+ private final Partitioner partitioner;
+ private final ShuffleWriteMetrics writeMetrics;
+ private final Serializer serializer;
+
+ /** Array of file writers, one for each partition */
+ private BlockObjectWriter[] partitionWriters;
+
+ public BypassMergeSortShuffleWriter(
+ SparkConf conf,
+ BlockManager blockManager,
+ Partitioner partitioner,
+ ShuffleWriteMetrics writeMetrics,
+ Serializer serializer) {
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
+ this.numPartitions = partitioner.numPartitions();
+ this.blockManager = blockManager;
+ this.partitioner = partitioner;
+ this.writeMetrics = writeMetrics;
+ this.serializer = serializer;
+ }
+
+ @Override
+ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+ assert (partitionWriters == null);
+ if (!records.hasNext()) {
+ return;
+ }
+ final SerializerInstance serInstance = serializer.newInstance();
+ final long openStartTime = System.nanoTime();
+ partitionWriters = new BlockObjectWriter[numPartitions];
+ for (int i = 0; i < numPartitions; i++) {
+ final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = tempShuffleBlockIdPlusFile._2();
+ final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+ partitionWriters[i] =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open();
+ }
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime);
+
+ while (records.hasNext()) {
+ final Product2<K, V> record = records.next();
+ final K key = record._1();
+ partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+ }
+
+ for (BlockObjectWriter writer : partitionWriters) {
+ writer.commitAndClose();
+ }
+ }
+
+ @Override
+ public long[] writePartitionedFile(
+ BlockId blockId,
+ TaskContext context,
+ File outputFile) throws IOException {
+ // Track location of the partition starts in the output file
+ final long[] lengths = new long[numPartitions];
+ if (partitionWriters == null) {
+ // We were passed an empty iterator
+ return lengths;
+ }
+
+ final FileOutputStream out = new FileOutputStream(outputFile, true);
+ final long writeStartTime = System.nanoTime();
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < numPartitions; i++) {
+ final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file());
+ boolean copyThrewException = true;
+ try {
+ lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ }
+ if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) {
+ logger.error("Unable to delete file for partition {}", i);
+ }
+ }
+ threwException = false;
+ } finally {
+ Closeables.close(out, threwException);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+ }
+ partitionWriters = null;
+ return lengths;
+ }
+
+ @Override
+ public void stop() throws IOException {
+ if (partitionWriters != null) {
+ try {
+ final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
+ for (BlockObjectWriter writer : partitionWriters) {
+ // This method explicitly does _not_ throw exceptions:
+ writer.revertPartialWritesAndClose();
+ if (!diskBlockManager.getFile(writer.blockId()).delete()) {
+ logger.error("Error while deleting file for block {}", writer.blockId());
+ }
+ }
+ } finally {
+ partitionWriters = null;
+ }
+ }
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
new file mode 100644
index 000000000000..656ea0401a14
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.collection.Iterator;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.TaskContext;
+import org.apache.spark.storage.BlockId;
+
+/**
+ * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
+ */
+@Private
+public interface SortShuffleFileWriter<K, V> {
+
+ void insertAll(Iterator<Product2<K, V>> records) throws IOException;
+
+ /**
+ * Write all the data added into this shuffle sorter into a file in the disk store. This is
+ * called by the SortShuffleWriter and can go through an efficient path of just concatenating
+ * binary files if we decided to avoid merge-sorting.
+ *
+ * @param blockId block ID to write to. The index file will be blockId.name + ".index".
+ * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
+ * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+ */
+ long[] writePartitionedFile(
+ BlockId blockId,
+ TaskContext context,
+ File outputFile) throws IOException;
+
+ void stop() throws IOException;
+}
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
new file mode 100644
index 000000000000..b146f8a78412
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
@@ -0,0 +1,12 @@
+# Set everything to be logged to the console
+log4j.rootCategory=WARN, console
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+
+# Settings to quiet third party logs that are too verbose
+log4j.logger.org.spark-project.jetty=WARN
+log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
+log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
index 013db8df9b36..0b450dc76bc3 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -50,4 +50,9 @@ $(function() {
$("span.additional-metric-title").click(function() {
$(this).parent().find('input[type="checkbox"]').trigger('click');
});
+
+ // Show the full job description when the description span is double-clicked.
+ $(".description-input").dblclick(function() {
+ $(this).removeClass("description-input").addClass("description-input-full");
+ });
});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
index e96af8768daa..9fa53baaf421 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
@@ -140,7 +140,8 @@ function renderDagViz(forJob) {
// Find cached RDDs and mark them as such
metadataContainer().selectAll(".cached-rdd").each(function(v) {
- var nodeId = VizConstants.nodePrefix + d3.select(this).text();
+ var rddId = d3.select(this).text().trim();
+ var nodeId = VizConstants.nodePrefix + rddId;
svg.selectAll("g." + nodeId).classed("cached", true);
});
@@ -150,7 +151,7 @@ function renderDagViz(forJob) {
/* Render the RDD DAG visualization on the stage page. */
function renderDagVizForStage(svgContainer) {
var metadata = metadataContainer().select(".stage-metadata");
- var dot = metadata.select(".dot-file").text();
+ var dot = metadata.select(".dot-file").text().trim();
var containerId = VizConstants.graphPrefix + metadata.attr("stage-id");
var container = svgContainer.append("g").attr("id", containerId);
renderDot(dot, container, false);
@@ -235,7 +236,7 @@ function renderDagVizForJob(svgContainer) {
// them separately later. Note that we cannot draw them now because we need to
// put these edges in a separate container that is on top of all stage graphs.
metadata.selectAll(".incoming-edge").each(function(v) {
- var edge = d3.select(this).text().split(","); // e.g. 3,4 => [3, 4]
+ var edge = d3.select(this).text().trim().split(","); // e.g. 3,4 => [3, 4]
crossStageEdges.push(edge);
});
});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index e7c1d475d4e5..b1cef4704224 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -135,6 +135,14 @@ pre {
display: block;
}
+.description-input-full {
+ overflow: hidden;
+ text-overflow: ellipsis;
+ width: 100%;
+ white-space: normal;
+ display: block;
+}
+
.stacktrace-details {
max-height: 300px;
overflow-y: auto;
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index b8a5f5016860..ceeb58075d34 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -34,8 +34,8 @@ case class Aggregator[K, V, C] (
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
- // When spilling is enabled sorting will happen externally, but not necessarily with an
- // ExternalSorter.
+ // When spilling is enabled sorting will happen externally, but not necessarily with an
+ // ExternalSorter.
private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
@deprecated("use combineValuesByKey with TaskContext argument", "0.9.0")
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 951460475264..49329423dca7 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -101,6 +101,9 @@ private[spark] class ExecutorAllocationManager(
private val executorIdleTimeoutS = conf.getTimeAsSeconds(
"spark.dynamicAllocation.executorIdleTimeout", "60s")
+ private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds(
+ "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * executorIdleTimeoutS}s")
+
// During testing, the methods to actually kill and add executors are mocked out
private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
@@ -150,6 +153,13 @@ private[spark] class ExecutorAllocationManager(
// Metric source for ExecutorAllocationManager to expose internal status to MetricsSystem.
val executorAllocationManagerSource = new ExecutorAllocationManagerSource
+ // Whether we are still waiting for the initial set of executors to be allocated.
+ // While this is true, we will not cancel outstanding executor requests. This is
+ // set to false when:
+ // (1) a stage is submitted, or
+ // (2) an executor idle timeout has elapsed.
+ @volatile private var initializing: Boolean = true
+
/**
* Verify that the settings specified through the config are valid.
* If not, throw an appropriate exception.
@@ -240,6 +250,7 @@ private[spark] class ExecutorAllocationManager(
removeTimes.retain { case (executorId, expireTime) =>
val expired = now >= expireTime
if (expired) {
+ initializing = false
removeExecutor(executorId)
}
!expired
@@ -261,15 +272,23 @@ private[spark] class ExecutorAllocationManager(
private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized {
val maxNeeded = maxNumExecutorsNeeded
- if (maxNeeded < numExecutorsTarget) {
+ if (initializing) {
+ // Do not change our target while we are still initializing;
+ // otherwise the first job may have to ramp up unnecessarily.
+ 0
+ } else if (maxNeeded < numExecutorsTarget) {
// The target number exceeds the number we actually need, so stop adding new
// executors and inform the cluster manager to cancel the extra pending requests
val oldNumExecutorsTarget = numExecutorsTarget
numExecutorsTarget = math.max(maxNeeded, minNumExecutors)
- client.requestTotalExecutors(numExecutorsTarget)
numExecutorsToAdd = 1
- logInfo(s"Lowering target number of executors to $numExecutorsTarget because " +
- s"not all requests are actually needed (previously $oldNumExecutorsTarget)")
+
+ // If the new target has not changed, avoid sending a message to the cluster manager
+ if (numExecutorsTarget < oldNumExecutorsTarget) {
+ client.requestTotalExecutors(numExecutorsTarget)
+ logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " +
+ s"$oldNumExecutorsTarget) because not all requested executors are actually needed")
+ }
numExecutorsTarget - oldNumExecutorsTarget
} else if (addTime != NOT_SET && now >= addTime) {
val delta = addExecutors(maxNeeded)
@@ -443,9 +462,23 @@ private[spark] class ExecutorAllocationManager(
private def onExecutorIdle(executorId: String): Unit = synchronized {
if (executorIds.contains(executorId)) {
if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
+ // Note that it is not necessary to query the executors since all the cached
+ // blocks we are concerned with are reported to the driver. This does not
+ // include broadcast blocks.
+ val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId)
+ val now = clock.getTimeMillis()
+ val timeout = {
+ if (hasCachedBlocks) {
+ // Use a different timeout if the executor has cached blocks.
+ now + cachedExecutorIdleTimeoutS * 1000
+ } else {
+ now + executorIdleTimeoutS * 1000
+ }
+ }
+ val realTimeout = if (timeout <= 0) Long.MaxValue else timeout // overflow
+ removeTimes(executorId) = realTimeout
logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
- s"scheduled to run on the executor (to expire in $executorIdleTimeoutS seconds)")
- removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeoutS * 1000
+ s"scheduled to run on the executor (to expire in ${(realTimeout - now)/1000} seconds)")
}
} else {
logWarning(s"Attempted to mark unknown executor $executorId idle")
@@ -477,6 +510,7 @@ private[spark] class ExecutorAllocationManager(
private var numRunningTasks: Int = _
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+ initializing = false
val stageId = stageSubmitted.stageInfo.stageId
val numTasks = stageSubmitted.stageInfo.numTasks
allocationManager.synchronized {
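
The hunks above hinge on one idea: an idle executor that still holds cached blocks gets the new `spark.dynamicAllocation.cachedExecutorIdleTimeout` (defaulting to twice the plain idle timeout) instead of `spark.dynamicAllocation.executorIdleTimeout`. A minimal standalone sketch of that selection using only public `SparkConf` accessors; the harness and the values chosen are illustrative, not part of the patch:

```scala
import org.apache.spark.SparkConf

// Sketch of the timeout selection mirrored from the patched onExecutorIdle logic.
object IdleTimeoutSketch {
  def expiryTimeMs(conf: SparkConf, hasCachedBlocks: Boolean, nowMs: Long): Long = {
    val idleS = conf.getTimeAsSeconds("spark.dynamicAllocation.executorIdleTimeout", "60s")
    val cachedIdleS = conf.getTimeAsSeconds(
      "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * idleS}s")
    val expiry = nowMs + (if (hasCachedBlocks) cachedIdleS else idleS) * 1000
    if (expiry <= 0) Long.MaxValue else expiry // overflow guard, as in the patch
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false).set("spark.dynamicAllocation.executorIdleTimeout", "30s")
    val now = System.currentTimeMillis()
    println(expiryTimeMs(conf, hasCachedBlocks = false, now) - now) // 30000
    println(expiryTimeMs(conf, hasCachedBlocks = true, now) - now)  // 60000 (2x the default)
  }
}
```
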
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 91f9ef8ce718..48792a958130 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
}
override def isCompleted: Boolean = jobWaiter.jobFinished
-
+
override def isCancelled: Boolean = _cancelled
override def value: Option[Try[T]] = {
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index f2b024ff6cb6..6909015ff66e 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -29,7 +29,7 @@ import org.apache.spark.util.{ThreadUtils, Utils}
/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
- * components to convey liveness or execution information for in-progress tasks. It will also
+ * components to convey liveness or execution information for in-progress tasks. It will also
* expire the hosts that have not heartbeated for more than spark.network.timeout.
*/
private[spark] case class Heartbeat(
@@ -43,8 +43,8 @@ private[spark] case class Heartbeat(
*/
private[spark] case object TaskSchedulerIsSet
-private[spark] case object ExpireDeadHosts
-
+private[spark] case object ExpireDeadHosts
+
private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
/**
@@ -62,18 +62,18 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
// "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses
// "milliseconds"
- private val slaveTimeoutMs =
+ private val slaveTimeoutMs =
sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s")
- private val executorTimeoutMs =
+ private val executorTimeoutMs =
sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000
-
+
// "spark.network.timeoutInterval" uses "seconds", while
// "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds"
- private val timeoutIntervalMs =
+ private val timeoutIntervalMs =
sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s")
- private val checkTimeoutIntervalMs =
+ private val checkTimeoutIntervalMs =
sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000
-
+
private var timeoutCheckingTask: ScheduledFuture[_] = null
// "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not
@@ -140,7 +140,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
}
}
}
-
+
override def onStop(): Unit = {
if (timeoutCheckingTask != null) {
timeoutCheckingTask.cancel(true)
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index 7e706bcc42f0..7cf7bc0dc681 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -50,8 +50,8 @@ private[spark] class HttpFileServer(
def stop() {
httpServer.stop()
-
- // If we only stop sc, but the driver process still run as a services then we need to delete
+
+ // If we only stop sc, but the driver process still runs as a service, then we need to delete
// the tmp dir, if not, it will create too many tmp dirs
try {
Utils.deleteRecursively(baseDir)
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index 419d093d5564..7fcb7830e7b0 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -121,13 +121,25 @@ trait Logging {
if (usingLog4j12) {
val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
if (!log4j12Initialized) {
- val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
- case Some(url) =>
- PropertyConfigurator.configure(url)
- System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
- case None =>
- System.err.println(s"Spark was unable to load $defaultLogProps")
+ if (Utils.isInInterpreter) {
+ val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties"
+ Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match {
+ case Some(url) =>
+ PropertyConfigurator.configure(url)
+ System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps")
+ System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")")
+ case None =>
+ System.err.println(s"Spark was unable to load $replDefaultLogProps")
+ }
+ } else {
+ val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
+ case Some(url) =>
+ PropertyConfigurator.configure(url)
+ System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
+ case None =>
+ System.err.println(s"Spark was unable to load $defaultLogProps")
+ }
}
}
}
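
The REPL branch above also prints a hint about `sc.setLogLevel`, which ships with this release. As a usage note (not part of this patch), adjusting verbosity from the shell looks like:

```scala
// In spark-shell, where `sc` is the pre-created SparkContext:
sc.setLogLevel("INFO") // other valid levels include ALL, DEBUG, WARN, ERROR, OFF
```
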
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 018422827e1c..862ffe868f58 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -21,7 +21,7 @@ import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.{HashSet, Map}
+import scala.collection.mutable.{HashMap, HashSet, Map}
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -284,6 +284,53 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
}
+ /**
+ * Return a list of locations that each have a fraction of the map output greater than the
+ * specified threshold.
+ *
+ * @param shuffleId id of the shuffle
+ * @param reducerId id of the reduce task
+ * @param numReducers total number of reducers in the shuffle
+ * @param fractionThreshold fraction of total map output size that a location must have
+ * for it to be considered large.
+ *
+ * This method is not thread-safe.
+ */
+ def getLocationsWithLargestOutputs(
+ shuffleId: Int,
+ reducerId: Int,
+ numReducers: Int,
+ fractionThreshold: Double)
+ : Option[Array[BlockManagerId]] = {
+
+ if (mapStatuses.contains(shuffleId)) {
+ val statuses = mapStatuses(shuffleId)
+ if (statuses.nonEmpty) {
+ // HashMap to add up sizes of all blocks at the same location
+ val locs = new HashMap[BlockManagerId, Long]
+ var totalOutputSize = 0L
+ var mapIdx = 0
+ while (mapIdx < statuses.length) {
+ val status = statuses(mapIdx)
+ val blockSize = status.getSizeForBlock(reducerId)
+ if (blockSize > 0) {
+ locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize
+ totalOutputSize += blockSize
+ }
+ mapIdx = mapIdx + 1
+ }
+ val topLocs = locs.filter { case (loc, size) =>
+ size.toDouble / totalOutputSize >= fractionThreshold
+ }
+ // Return if we have any locations which satisfy the required threshold
+ if (topLocs.nonEmpty) {
+ return Some(topLocs.map(_._1).toArray)
+ }
+ }
+ }
+ None
+ }
+
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
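
To show how a scheduler-side caller could consume the new `getLocationsWithLargestOutputs` helper, here is a hedged fragment; the ids, the 0.2 threshold, and the helper name are illustrative, and because `MapOutputTrackerMaster` is `private[spark]`, such code would have to live inside the `org.apache.spark` package:

```scala
import org.apache.spark.MapOutputTrackerMaster
import org.apache.spark.storage.BlockManagerId

// Hypothetical helper: hosts that serve at least 20% of the map output for one reducer.
def preferredReduceHosts(tracker: MapOutputTrackerMaster): Seq[String] = {
  val locs: Option[Array[BlockManagerId]] =
    tracker.getLocationsWithLargestOutputs(
      shuffleId = 0, reducerId = 3, numReducers = 8, fractionThreshold = 0.2)
  locs.map(_.map(_.host).toSeq).getOrElse(Seq.empty)
}
```
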
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 8aed1e20e068..673ef49e7c1c 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -192,7 +192,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
// key used to store the spark secret in the Hadoop UGI
private val sparkSecretLookupKey = "sparkCookie"
- private val authOn = sparkConf.getBoolean("spark.authenticate", false)
+ private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false)
// keep spark.ui.acls.enable for backwards compatibility with 1.0
private var aclsOn =
sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false))
@@ -365,10 +365,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
cookie
} else {
// user must have set spark.authenticate.secret config
- sparkConf.getOption("spark.authenticate.secret") match {
+ // For Master/Worker, the auth secret is in the conf; for executors, it is in an env variable
+ sys.env.get(SecurityManager.ENV_AUTH_SECRET)
+ .orElse(sparkConf.getOption(SecurityManager.SPARK_AUTH_SECRET_CONF)) match {
case Some(value) => value
case None => throw new Exception("Error: a secret key must be specified via the " +
- "spark.authenticate.secret config")
+ SecurityManager.SPARK_AUTH_SECRET_CONF + " config")
}
}
sCookie
@@ -449,3 +451,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
override def getSaslUser(appId: String): String = getSaslUser()
override def getSecretKey(appId: String): String = getSecretKey()
}
+
+private[spark] object SecurityManager {
+
+ val SPARK_AUTH_CONF: String = "spark.authenticate"
+ val SPARK_AUTH_SECRET_CONF: String = "spark.authenticate.secret"
+ // This is used to set the auth secret in an executor's env variable. It should have the same
+ // value as SPARK_AUTH_SECRET_CONF set in SparkConf.
+ val ENV_AUTH_SECRET = "_SPARK_AUTH_SECRET"
+}
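
The lookup order introduced above (executor-side env variable first, then the conf key) can be summarized in a few lines. A sketch with a hypothetical helper; the string literals match the new constants, which are themselves `private[spark]`:

```scala
import org.apache.spark.SparkConf

// Same precedence as the patch: _SPARK_AUTH_SECRET in the environment wins,
// then spark.authenticate.secret from the SparkConf.
def resolveAuthSecret(conf: SparkConf, env: Map[String, String] = sys.env): String = {
  env.get("_SPARK_AUTH_SECRET")
    .orElse(conf.getOption("spark.authenticate.secret"))
    .getOrElse(throw new IllegalStateException(
      "a secret key must be specified via the spark.authenticate.secret config"))
}
```
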
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 4b5bcb54aa87..6cf36fbbd625 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsBytes(key: String, defaultValue: String): Long = {
Utils.byteStringAsBytes(get(key, defaultValue))
}
-
+
/**
* Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no
* suffix is provided then Kibibytes are assumed.
@@ -244,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsKb(key: String, defaultValue: String): Long = {
Utils.byteStringAsKb(get(key, defaultValue))
}
-
+
/**
* Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no
* suffix is provided then Mebibytes are assumed.
@@ -261,7 +261,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsMb(key: String, defaultValue: String): Long = {
Utils.byteStringAsMb(get(key, defaultValue))
}
-
+
/**
* Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no
* suffix is provided then Gibibytes are assumed.
@@ -278,7 +278,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsGb(key: String, defaultValue: String): Long = {
Utils.byteStringAsGb(get(key, defaultValue))
}
-
+
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
Option(settings.get(key)).orElse(getDeprecatedConfig(key, this))
@@ -480,7 +480,7 @@ private[spark] object SparkConf extends Logging {
"spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " +
"are no longer accepted. To specify the equivalent now, one may use '64k'.")
)
-
+
Map(configs.map { cfg => (cfg.key -> cfg) } : _*)
}
@@ -508,7 +508,7 @@ private[spark] object SparkConf extends Logging {
"spark.reducer.maxSizeInFlight" -> Seq(
AlternateConfig("spark.reducer.maxMbInFlight", "1.4")),
"spark.kryoserializer.buffer" ->
- Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4",
+ Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4",
translation = s => s"${(s.toDouble * 1000).toInt}k")),
"spark.kryoserializer.buffer.max" -> Seq(
AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")),
@@ -557,7 +557,7 @@ private[spark] object SparkConf extends Logging {
def isExecutorStartupConf(name: String): Boolean = {
isAkkaConf(name) ||
name.startsWith("spark.akka") ||
- name.startsWith("spark.auth") ||
+ (name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) ||
name.startsWith("spark.ssl") ||
isSparkPortConf(name)
}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index a18595408952..b0665570e268 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -20,6 +20,8 @@ package org.apache.spark
import java.io.File
import java.net.Socket
+import akka.actor.ActorSystem
+
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.util.Properties
@@ -75,7 +77,8 @@ class SparkEnv (
val conf: SparkConf) extends Logging {
// TODO Remove actorSystem
- val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
+ @deprecated("Actor system is no longer supported as of 1.4")
+ val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index fe6320b504e1..a1ebbecf93b7 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -51,7 +51,7 @@ private[spark] object TestUtils {
classpathUrls: Seq[URL] = Seq()): URL = {
val tempDir = Utils.createTempDir()
val files1 = for (name <- classNames) yield {
- createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls)
+ createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls)
}
val files2 = for ((childName, baseName) <- classNamesWithBase) yield {
createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 61af867b11b9..a650df605b92 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -137,7 +137,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double])
*/
def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD =
sample(withReplacement, fraction, Utils.random.nextLong)
-
+
/**
* Return a sampled subset of this RDD.
*/
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index db4e996feb31..ed312770ee13 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -101,7 +101,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
/**
* Return a sampled subset of this RDD.
- *
+ *
* @param withReplacement can elements be sampled multiple times (replaced when sampled out)
* @param fraction expected size of the sample as a fraction of this RDD's size
* without replacement: probability that each element is chosen; fraction must be [0, 1]
@@ -109,10 +109,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
*/
def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] =
sample(withReplacement, fraction, Utils.random.nextLong)
-
+
/**
* Return a sampled subset of this RDD.
- *
+ *
* @param withReplacement can elements be sampled multiple times (replaced when sampled out)
* @param fraction expected size of the sample as a fraction of this RDD's size
* without replacement: probability that each element is chosen; fraction must be [0, 1]
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index b8e15f38a20d..c95615a5a930 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -60,10 +60,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
@deprecated("Use partitions() instead.", "1.1.0")
def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq)
-
+
/** Set of partitions in this RDD. */
def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq)
+ /** The partitioner of this RDD. */
+ def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner)
+
/** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
def context: SparkContext = rdd.context
@@ -492,9 +495,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}
- def takeSample(withReplacement: Boolean, num: Int): JList[T] =
+ def takeSample(withReplacement: Boolean, num: Int): JList[T] =
takeSample(withReplacement, num, Utils.random.nextLong)
-
+
def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index a77bf42ce1d3..0103f6c6ab67 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -425,6 +425,11 @@ private[spark] object PythonRDD extends Logging {
iter.foreach(write)
}
+ /** Create an RDD that has no partitions or elements. */
+ def emptyRDD[T](sc: JavaSparkContext): JavaRDD[T] = {
+ sc.emptyRDD[T]
+ }
+
/**
* Create an RDD from a path using [[org.apache.hadoop.mapred.SequenceFileInputFormat]],
* key and value class.
@@ -797,10 +802,10 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
- /**
+ /**
* We try to reuse a single Socket to transfer accumulator updates, as they are all added
* by the DAGScheduler's single-threaded actor anyway.
- */
+ */
@transient var socket: Socket = _
def openSocket(): Socket = synchronized {
@@ -843,6 +848,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
 * A wrapper for a Python Broadcast, which is written to disk by Python. It will also
 * write the data to disk after deserialization, so that Python can read it from disk.
*/
+// scalastyle:off no.finalize
private[spark] class PythonBroadcast(@transient var path: String) extends Serializable {
/**
@@ -884,3 +890,4 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
}
}
}
+// scalastyle:on no.finalize
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 0a91977928ce..1a5f2bca26c2 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -29,7 +29,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.LengthFieldBasedFrameDecoder
import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkConf}
/**
* Netty-based backend server that is used to communicate between R and Java.
@@ -41,14 +41,15 @@ private[spark] class RBackend {
private[this] var bossGroup: EventLoopGroup = null
def init(): Int = {
- bossGroup = new NioEventLoopGroup(2)
+ val conf = new SparkConf()
+ bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2))
val workerGroup = bossGroup
val handler = new RBackendHandler(this)
-
+
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
.channel(classOf[NioServerSocketChannel])
-
+
bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
def initChannel(ch: SocketChannel): Unit = {
ch.pipeline()
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 026a1b938035..2e86984c66b3 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -77,7 +77,7 @@ private[r] class RBackendHandler(server: RBackend)
val reply = bos.toByteArray
ctx.write(reply)
}
-
+
override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
ctx.flush()
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index e020458888e4..4dfa7325934f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -355,7 +355,6 @@ private[r] object RRDD {
val sparkConf = new SparkConf().setAppName(appName)
.setSparkHome(sparkHome)
- .setJars(jars)
// Override `master` if we have a user-specified value
if (master != "") {
@@ -373,7 +372,11 @@ private[r] object RRDD {
sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String])
}
- new JavaSparkContext(sparkConf)
+ val jsc = new JavaSparkContext(sparkConf)
+ jars.foreach { jar =>
+ jsc.addJar(jar)
+ }
+ jsc
}
/**
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 371dfe454d1a..56adc857d4ce 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -18,7 +18,7 @@
package org.apache.spark.api.r
import java.io.{DataInputStream, DataOutputStream}
-import java.sql.{Date, Time}
+import java.sql.{Timestamp, Date, Time}
import scala.collection.JavaConversions._
@@ -107,9 +107,12 @@ private[spark] object SerDe {
Date.valueOf(readString(in))
}
- def readTime(in: DataInputStream): Time = {
- val t = in.readDouble()
- new Time((t * 1000L).toLong)
+ def readTime(in: DataInputStream): Timestamp = {
+ val seconds = in.readDouble()
+ val sec = Math.floor(seconds).toLong
+ val t = new Timestamp(sec * 1000L)
+ t.setNanos(((seconds - sec) * 1e9).toInt)
+ t
}
def readBytesArr(in: DataInputStream): Array[Array[Byte]] = {
@@ -157,9 +160,11 @@ private[spark] object SerDe {
val keysLen = readInt(in)
val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
- val valuesType = readObjectType(in)
val valuesLen = readInt(in)
- val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType))
+ val values = (0 until valuesLen).map(_ => {
+ val valueType = readObjectType(in)
+ readTypedObject(in, valueType)
+ })
mapAsJavaMap(keys.zip(values).toMap)
} else {
new java.util.HashMap[Object, Object]()
@@ -225,6 +230,9 @@ private[spark] object SerDe {
case "java.sql.Time" =>
writeType(dos, "time")
writeTime(dos, value.asInstanceOf[Time])
+ case "java.sql.Timestamp" =>
+ writeType(dos, "time")
+ writeTime(dos, value.asInstanceOf[Timestamp])
case "[B" =>
writeType(dos, "raw")
writeBytes(dos, value.asInstanceOf[Array[Byte]])
@@ -287,6 +295,9 @@ private[spark] object SerDe {
out.writeDouble(value.getTime.toDouble / 1000.0)
}
+ def writeTime(out: DataOutputStream, value: Timestamp): Unit = {
+ out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9)
+ }
// NOTE: Only works for ASCII right now
def writeString(out: DataOutputStream, value: String): Unit = {
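
SerDe now exchanges timestamps with R as a double of fractional epoch seconds, splitting the value back into whole seconds plus nanoseconds on the way in. A standalone round trip of the same arithmetic, with an arbitrary example instant and no SerDe internals:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.sql.Timestamp

object TimestampRoundTrip {
  def main(args: Array[String]): Unit = {
    val original = new Timestamp(1430917381534L) // arbitrary instant, millis since the epoch

    // Write: whole seconds plus a fractional nanosecond part, as one double.
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    out.writeDouble((original.getTime / 1000).toDouble + original.getNanos.toDouble / 1e9)
    out.flush()

    // Read: split the double back into seconds and nanoseconds.
    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val seconds = in.readDouble()
    val sec = math.floor(seconds).toLong
    val restored = new Timestamp(sec * 1000L)
    restored.setNanos(((seconds - sec) * 1e9).toInt)

    println(original)
    println(restored) // same instant, up to double precision
  }
}
```
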
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 92bb5059a031..cfcc6d355801 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -82,13 +82,13 @@ object SparkSubmit {
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
// Exposed for testing
- private[spark] var exitFn: () => Unit = () => System.exit(1)
+ private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode)
private[spark] var printStream: PrintStream = System.err
private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str)
private[spark] def printErrorAndExit(str: String): Unit = {
printStream.println("Error: " + str)
printStream.println("Run with --help for usage help or --verbose for debug output")
- exitFn()
+ exitFn(1)
}
private[spark] def printVersionAndExit(): Unit = {
printStream.println("""Welcome to
@@ -99,7 +99,7 @@ object SparkSubmit {
/_/
""".format(SPARK_VERSION))
printStream.println("Type --help for more information.")
- exitFn()
+ exitFn(0)
}
def main(args: Array[String]): Unit = {
@@ -160,7 +160,7 @@ object SparkSubmit {
// detect exceptions with empty stack traces here, and treat them differently.
if (e.getStackTrace().length == 0) {
printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}")
- exitFn()
+ exitFn(1)
} else {
throw e
}
@@ -324,55 +324,20 @@ object SparkSubmit {
// Usage: PythonAppRunner [app arguments]
args.mainClass = "org.apache.spark.deploy.PythonRunner"
args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs
- args.files = mergeFileLists(args.files, args.primaryResource)
+ if (clusterManager != YARN) {
+ // The YARN backend distributes the primary file differently, so don't merge it.
+ args.files = mergeFileLists(args.files, args.primaryResource)
+ }
+ }
+ if (clusterManager != YARN) {
+ // The YARN backend handles python files differently, so don't merge the lists.
+ args.files = mergeFileLists(args.files, args.pyFiles)
}
- args.files = mergeFileLists(args.files, args.pyFiles)
if (args.pyFiles != null) {
sysProps("spark.submit.pyFiles") = args.pyFiles
}
}
- // In yarn mode for a python app, add pyspark archives to files
- // that can be distributed with the job
- if (args.isPython && clusterManager == YARN) {
- var pyArchives: String = null
- val pyArchivesEnvOpt = sys.env.get("PYSPARK_ARCHIVES_PATH")
- if (pyArchivesEnvOpt.isDefined) {
- pyArchives = pyArchivesEnvOpt.get
- } else {
- if (!sys.env.contains("SPARK_HOME")) {
- printErrorAndExit("SPARK_HOME does not exist for python application in yarn mode.")
- }
- val pythonPath = new ArrayBuffer[String]
- for (sparkHome <- sys.env.get("SPARK_HOME")) {
- val pyLibPath = Seq(sparkHome, "python", "lib").mkString(File.separator)
- val pyArchivesFile = new File(pyLibPath, "pyspark.zip")
- if (!pyArchivesFile.exists()) {
- printErrorAndExit("pyspark.zip does not exist for python application in yarn mode.")
- }
- val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip")
- if (!py4jFile.exists()) {
- printErrorAndExit("py4j-0.8.2.1-src.zip does not exist for python application " +
- "in yarn mode.")
- }
- pythonPath += pyArchivesFile.getAbsolutePath()
- pythonPath += py4jFile.getAbsolutePath()
- }
- pyArchives = pythonPath.mkString(",")
- }
-
- pyArchives = pyArchives.split(",").map { localPath =>
- val localURI = Utils.resolveURI(localPath)
- if (localURI.getScheme != "local") {
- args.files = mergeFileLists(args.files, localURI.toString)
- new Path(localPath).getName
- } else {
- localURI.getPath
- }
- }.mkString(File.pathSeparator)
- sysProps("spark.submit.pyArchives") = pyArchives
- }
-
// If we're running a R app, set the main class to our specific R runner
if (args.isR && deployMode == CLIENT) {
if (args.primaryResource == SPARKR_SHELL) {
@@ -386,19 +351,10 @@ object SparkSubmit {
}
}
- if (isYarnCluster) {
- // In yarn-cluster mode for a python app, add primary resource and pyFiles to files
- // that can be distributed with the job
- if (args.isPython) {
- args.files = mergeFileLists(args.files, args.primaryResource)
- args.files = mergeFileLists(args.files, args.pyFiles)
- }
-
+ if (isYarnCluster && args.isR) {
// In yarn-cluster mode for a R app, add primary resource to files
// that can be distributed with the job
- if (args.isR) {
- args.files = mergeFileLists(args.files, args.primaryResource)
- }
+ args.files = mergeFileLists(args.files, args.primaryResource)
}
// Special flag to avoid deprecation warnings at the client
@@ -425,9 +381,10 @@ object SparkSubmit {
// Yarn client only
OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"),
OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"),
- OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"),
OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"),
OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"),
+ OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"),
+ OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"),
// Yarn cluster only
OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"),
@@ -440,13 +397,11 @@ object SparkSubmit {
OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"),
OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"),
OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"),
-
- // Yarn client or cluster
- OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"),
- OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"),
+ OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"),
+ OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"),
// Other options
- OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES,
+ OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES,
sysProp = "spark.executor.cores"),
OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES,
sysProp = "spark.executor.memory"),
@@ -516,17 +471,18 @@ object SparkSubmit {
}
}
+ // Let YARN know it's a pyspark app, so it distributes needed libraries.
+ if (clusterManager == YARN && args.isPython) {
+ sysProps.put("spark.yarn.isPython", "true")
+ }
+
// In yarn-cluster mode, use yarn.Client as a wrapper around the user class
if (isYarnCluster) {
childMainClass = "org.apache.spark.deploy.yarn.Client"
if (args.isPython) {
- val mainPyFile = new Path(args.primaryResource).getName
- childArgs += ("--primary-py-file", mainPyFile)
+ childArgs += ("--primary-py-file", args.primaryResource)
if (args.pyFiles != null) {
- // These files will be distributed to each machine's working directory, so strip the
- // path prefix
- val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",")
- childArgs += ("--py-files", pyFilesNames)
+ childArgs += ("--py-files", args.pyFiles)
}
childArgs += ("--class", "org.apache.spark.deploy.PythonRunner")
} else if (args.isR) {
@@ -700,7 +656,7 @@ object SparkSubmit {
/**
* Return whether the given main class represents a sql shell.
*/
- private def isSqlShell(mainClass: String): Boolean = {
+ private[deploy] def isSqlShell(mainClass: String): Boolean = {
mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
}
@@ -869,18 +825,14 @@ private[spark] object SparkSubmitUtils {
md.addDependency(dd)
}
}
-
+
/** Add exclusion rules for dependencies already included in the spark-assembly */
def addExclusionRules(
ivySettings: IvySettings,
ivyConfName: String,
md: DefaultModuleDescriptor): Unit = {
// Add scala exclusion rule
- val scalaArtifacts = new ArtifactId(new ModuleId("*", "scala-library"), "*", "*", "*")
- val scalaDependencyExcludeRule =
- new DefaultExcludeRule(scalaArtifacts, ivySettings.getMatcher("glob"), null)
- scalaDependencyExcludeRule.addConfiguration(ivyConfName)
- md.addExcludeRule(scalaDependencyExcludeRule)
+ md.addExcludeRule(createExclusion("*:scala-library:*", ivySettings, ivyConfName))
// We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and
// other spark-streaming utility components. Underscore is there to differentiate between
@@ -889,13 +841,8 @@ private[spark] object SparkSubmitUtils {
"sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_")
components.foreach { comp =>
- val sparkArtifacts =
- new ArtifactId(new ModuleId("org.apache.spark", s"spark-$comp*"), "*", "*", "*")
- val sparkDependencyExcludeRule =
- new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null)
- sparkDependencyExcludeRule.addConfiguration(ivyConfName)
-
- md.addExcludeRule(sparkDependencyExcludeRule)
+ md.addExcludeRule(createExclusion(s"org.apache.spark:spark-$comp*:*", ivySettings,
+ ivyConfName))
}
}
@@ -908,6 +855,7 @@ private[spark] object SparkSubmitUtils {
* @param coordinates Comma-delimited string of maven coordinates
* @param remoteRepos Comma-delimited string of remote repositories other than maven central
* @param ivyPath The path to the local ivy repository
+ * @param exclusions Exclusions to apply when resolving transitive dependencies
* @return The comma-delimited path to the jars of the given maven artifacts including their
* transitive dependencies
*/
@@ -915,6 +863,7 @@ private[spark] object SparkSubmitUtils {
coordinates: String,
remoteRepos: Option[String],
ivyPath: Option[String],
+ exclusions: Seq[String] = Nil,
isTest: Boolean = false): String = {
if (coordinates == null || coordinates.trim.isEmpty) {
""
@@ -972,6 +921,10 @@ private[spark] object SparkSubmitUtils {
// add all supplied maven artifacts as dependencies
addDependenciesToIvy(md, artifacts, ivyConfName)
+ exclusions.foreach { e =>
+ md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName))
+ }
+
// resolve dependencies
val rr: ResolveReport = ivy.resolve(md, resolveOptions)
if (rr.hasError) {
@@ -988,6 +941,18 @@ private[spark] object SparkSubmitUtils {
}
}
}
+
+ private def createExclusion(
+ coords: String,
+ ivySettings: IvySettings,
+ ivyConfName: String): ExcludeRule = {
+ val c = extractMavenCoordinates(coords)(0)
+ val id = new ArtifactId(new ModuleId(c.groupId, c.artifactId), "*", "*", "*")
+ val rule = new DefaultExcludeRule(id, ivySettings.getMatcher("glob"), null)
+ rule.addConfiguration(ivyConfName)
+ rule
+ }
+
}
/**
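
The new `exclusions` parameter of `resolveMavenCoordinates` takes `groupId:artifactId` pairs and turns each into a glob exclusion rule via `createExclusion`. A hedged fragment of how a caller might pass it; the coordinates are examples only, and since `SparkSubmitUtils` is `private[spark]`, this would sit in that package (or in a test):

```scala
import org.apache.spark.deploy.SparkSubmitUtils

// Resolve one artifact while excluding a transitive dependency (example coordinates).
val jarPaths: String = SparkSubmitUtils.resolveMavenCoordinates(
  coordinates = "com.databricks:spark-csv_2.10:1.0.3",
  remoteRepos = None,                                  // no extra repositories beyond the defaults
  ivyPath = None,                                      // default ivy cache location
  exclusions = Seq("org.apache.commons:commons-csv"))  // expanded to a glob exclude rule
```
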
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index c0e4c771908b..b7429a901e16 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -17,12 +17,15 @@
package org.apache.spark.deploy
+import java.io.{ByteArrayOutputStream, PrintStream}
+import java.lang.reflect.InvocationTargetException
import java.net.URI
import java.util.{List => JList}
import java.util.jar.JarFile
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.io.Source
import org.apache.spark.deploy.SparkSubmitAction._
import org.apache.spark.launcher.SparkSubmitArgumentsParser
@@ -169,6 +172,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
numExecutors = Option(numExecutors)
.getOrElse(sparkProperties.get("spark.executor.instances").orNull)
+ keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull
+ principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && !isR && primaryResource != null) {
@@ -410,6 +415,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
case VERSION =>
SparkSubmit.printVersionAndExit()
+ case USAGE_ERROR =>
+ printUsageAndExit(1)
+
case _ =>
throw new IllegalArgumentException(s"Unexpected argument '$opt'.")
}
@@ -447,11 +455,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
if (unknownParam != null) {
outStream.println("Unknown/unsupported param " + unknownParam)
}
- outStream.println(
+ val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse(
"""Usage: spark-submit [options] [app arguments]
|Usage: spark-submit --kill [submission ID] --master [spark://...]
- |Usage: spark-submit --status [submission ID] --master [spark://...]
- |
+ |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin)
+ outStream.println(command)
+
+ outStream.println(
+ """
|Options:
| --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local.
| --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or
@@ -523,6 +534,65 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
| delegation tokens periodically.
""".stripMargin
)
- SparkSubmit.exitFn()
+
+ if (SparkSubmit.isSqlShell(mainClass)) {
+ outStream.println("CLI options:")
+ outStream.println(getSqlShellOptions())
+ }
+
+ SparkSubmit.exitFn(exitCode)
}
+
+ /**
+ * Run the Spark SQL CLI main class with the "--help" option and capture its output, then filter
+ * the results to remove unwanted lines.
+ *
+ * Since the CLI will call `System.exit()`, we install a security manager to prevent that call
+ * from working, and restore the original one afterwards.
+ */
+ private def getSqlShellOptions(): String = {
+ val currentOut = System.out
+ val currentErr = System.err
+ val currentSm = System.getSecurityManager()
+ try {
+ val out = new ByteArrayOutputStream()
+ val stream = new PrintStream(out)
+ System.setOut(stream)
+ System.setErr(stream)
+
+ val sm = new SecurityManager() {
+ override def checkExit(status: Int): Unit = {
+ throw new SecurityException()
+ }
+
+ override def checkPermission(perm: java.security.Permission): Unit = {}
+ }
+ System.setSecurityManager(sm)
+
+ try {
+ Class.forName(mainClass).getMethod("main", classOf[Array[String]])
+ .invoke(null, Array(HELP))
+ } catch {
+ case e: InvocationTargetException =>
+ // Ignore SecurityException, since we throw it above.
+ if (!e.getCause().isInstanceOf[SecurityException]) {
+ throw e
+ }
+ }
+
+ stream.flush()
+
+ // Get the output and discard any unnecessary lines from it.
+ Source.fromString(new String(out.toByteArray())).getLines
+ .filter { line =>
+ !line.startsWith("log4j") && !line.startsWith("usage")
+ }
+ .mkString("\n")
+ } finally {
+ System.setSecurityManager(currentSm)
+ System.setOut(currentOut)
+ System.setErr(currentErr)
+ }
+ }
+
}
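
`getSqlShellOptions` above relies on a `java.lang.SecurityManager` that converts `System.exit` into a `SecurityException` so the CLI's `--help` output can be captured. The same trap, shown as a generic standalone sketch (the helper name is hypothetical):

```scala
// Run a block while turning System.exit into an exception; restore the old manager afterwards.
def withExitTrapped[T](body: => T): Option[T] = {
  val previous = System.getSecurityManager
  val trap = new java.lang.SecurityManager {
    override def checkExit(status: Int): Unit = throw new SecurityException(s"exit($status)")
    override def checkPermission(perm: java.security.Permission): Unit = {}
  }
  System.setSecurityManager(trap)
  try {
    Some(body)
  } catch {
    case _: SecurityException => None // body attempted to call System.exit
  } finally {
    System.setSecurityManager(previous)
  }
}

// withExitTrapped { System.exit(2) } would return None instead of killing the JVM.
```
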
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
index 298a8201960d..5f5e0fe1c34d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
@@ -17,6 +17,9 @@
package org.apache.spark.deploy.history
+import java.util.zip.ZipOutputStream
+
+import org.apache.spark.SparkException
import org.apache.spark.ui.SparkUI
private[spark] case class ApplicationAttemptInfo(
@@ -62,4 +65,12 @@ private[history] abstract class ApplicationHistoryProvider {
*/
def getConfig(): Map[String, String] = Map()
+ /**
+ * Writes out the event logs to the output stream provided. The logs will be compressed into a
+ * single zip file and written out.
+ * @throws SparkException if the logs for the app id cannot be found.
+ */
+ @throws(classOf[SparkException])
+ def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit
+
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 45c2be34c868..db383b9823d3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -17,16 +17,18 @@
package org.apache.spark.deploy.history
-import java.io.{BufferedInputStream, FileNotFoundException, IOException, InputStream}
+import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream}
import java.util.concurrent.{ExecutorService, Executors, TimeUnit}
+import java.util.zip.{ZipEntry, ZipOutputStream}
import scala.collection.mutable
+import com.google.common.io.ByteStreams
import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.fs.permission.AccessControlException
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.io.CompressionCodec
import org.apache.spark.scheduler._
@@ -59,7 +61,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
.map { d => Utils.resolveURI(d).toString }
.getOrElse(DEFAULT_LOG_DIR)
- private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf))
+ private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
+ private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf)
// Used by check event thread and clean log thread.
// Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs
@@ -157,7 +160,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
replayBus.addListener(appListener)
val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus)
- ui.setAppName(s"${appInfo.name} ($appId)")
+ appInfo.foreach { app => ui.setAppName(s"${app.name} ($appId)") }
val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false)
ui.getSecurityManager.setAcls(uiAclsEnabled)
@@ -219,6 +222,58 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
}
}
+ override def writeEventLogs(
+ appId: String,
+ attemptId: Option[String],
+ zipStream: ZipOutputStream): Unit = {
+
+ /**
+ * This method compresses the file passed in, and writes the compressed data out into the
+ * [[ZipOutputStream]] provided. Each file is written as a new [[ZipEntry]] with its name being
+ * the name of the file being compressed.
+ */
+ def zipFileToStream(file: Path, entryName: String, outputStream: ZipOutputStream): Unit = {
+ val fs = FileSystem.get(hadoopConf)
+ val inputStream = fs.open(file, 1 * 1024 * 1024) // 1MB Buffer
+ try {
+ outputStream.putNextEntry(new ZipEntry(entryName))
+ ByteStreams.copy(inputStream, outputStream)
+ outputStream.closeEntry()
+ } finally {
+ inputStream.close()
+ }
+ }
+
+ applications.get(appId) match {
+ case Some(appInfo) =>
+ try {
+ // If no attempt is specified, or there is no attemptId for attempts, return all attempts
+ appInfo.attempts.filter { attempt =>
+ attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get
+ }.foreach { attempt =>
+ val logPath = new Path(logDir, attempt.logPath)
+ // If this is a legacy directory, then add the directory to the zipStream and add
+ // each file to that directory.
+ if (isLegacyLogDirectory(fs.getFileStatus(logPath))) {
+ val files = fs.listStatus(logPath)
+ zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/"))
+ zipStream.closeEntry()
+ files.foreach { file =>
+ val path = file.getPath
+ zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream)
+ }
+ } else {
+ zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream)
+ }
+ }
+ } finally {
+ zipStream.close()
+ }
+ case None => throw new SparkException(s"Logs for $appId not found.")
+ }
+ }
+
+
/**
* Replay the log files in the list and merge the list of old applications with new ones
*/
@@ -227,8 +282,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
val newAttempts = logs.flatMap { fileStatus =>
try {
val res = replay(fileStatus, bus)
- logInfo(s"Application log ${res.logPath} loaded successfully.")
- Some(res)
+ res match {
+ case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.")
+ case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " +
+ "The application may not have started.")
+ }
+ res
} catch {
case e: Exception =>
logError(
@@ -374,9 +433,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
/**
* Replays the events in the specified log file and returns information about the associated
- * application.
+ * application. Return `None` if the application ID cannot be located.
*/
- private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationAttemptInfo = {
+ private def replay(
+ eventLog: FileStatus,
+ bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = {
val logPath = eventLog.getPath()
logInfo(s"Replaying log path: $logPath")
val logInput =
@@ -390,16 +451,18 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
val appCompleted = isApplicationCompleted(eventLog)
bus.addListener(appListener)
bus.replay(logInput, logPath.toString, !appCompleted)
- new FsApplicationAttemptInfo(
- logPath.getName(),
- appListener.appName.getOrElse(NOT_STARTED),
- appListener.appId.getOrElse(logPath.getName()),
- appListener.appAttemptId,
- appListener.startTime.getOrElse(-1L),
- appListener.endTime.getOrElse(-1L),
- getModificationTime(eventLog).get,
- appListener.sparkUser.getOrElse(NOT_STARTED),
- appCompleted)
+ appListener.appId.map { appId =>
+ new FsApplicationAttemptInfo(
+ logPath.getName(),
+ appListener.appName.getOrElse(NOT_STARTED),
+ appId,
+ appListener.appAttemptId,
+ appListener.startTime.getOrElse(-1L),
+ appListener.endTime.getOrElse(-1L),
+ getModificationTime(eventLog).get,
+ appListener.sparkUser.getOrElse(NOT_STARTED),
+ appCompleted)
+ }
} finally {
logInput.close()
}
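
`writeEventLogs` streams every event log of the selected attempts into the zip, one `ZipEntry` per file. The same entry-per-file pattern against the local filesystem, with hypothetical paths (the patch itself reads through Hadoop's `FileSystem` and Guava's `ByteStreams`):

```scala
import java.io.FileOutputStream
import java.nio.file.{Files, Paths}
import java.util.zip.{ZipEntry, ZipOutputStream}

object ZipLogsSketch {
  def main(args: Array[String]): Unit = {
    val logFiles = Seq("eventLogs/app-1_attempt1", "eventLogs/app-1_attempt2") // hypothetical
    val zipStream = new ZipOutputStream(new FileOutputStream("eventLogs.zip"))
    try {
      logFiles.foreach { name =>
        zipStream.putNextEntry(new ZipEntry(Paths.get(name).getFileName.toString))
        Files.copy(Paths.get(name), zipStream) // copy the file body into the current entry
        zipStream.closeEntry()
      }
    } finally {
      zipStream.close() // finishes the archive as well
    }
  }
}
```
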
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index 5a0eb585a904..10638afb7490 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.history
import java.util.NoSuchElementException
+import java.util.zip.ZipOutputStream
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import com.google.common.cache._
@@ -173,6 +174,13 @@ class HistoryServer(
getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo)
}
+ override def writeEventLogs(
+ appId: String,
+ attemptId: Option[String],
+ zipStream: ZipOutputStream): Unit = {
+ provider.writeEventLogs(appId, attemptId, zipStream)
+ }
+
/**
* Returns the provider configuration to show in the listing page.
*
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index a2a97a7877ce..4692d22651c9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -23,7 +23,7 @@ import org.apache.spark.util.Utils
/**
 * Command-line parser for the history server.
*/
-private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String])
+private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String])
extends Logging {
private var propertiesFile: String = null
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 80db6d474b5c..328d95a7a0c6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkCuratorUtil
private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization)
extends PersistenceEngine
with Logging {
-
+
private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 756927682cd2..6a7c74020bac 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -75,6 +75,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory")
val workers = state.workers.sortBy(_.id)
+ val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE)
val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time",
@@ -108,12 +109,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}.getOrElse { Seq.empty }
}
-              Workers: {state.workers.size}
-              Cores: {state.workers.map(_.cores).sum} Total,
-              {state.workers.map(_.coresUsed).sum} Used
-              Memory:
-              {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total,
-              {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used
+              Alive Workers: {aliveWorkers.size}
+              Cores in use: {aliveWorkers.map(_.cores).sum} Total,
+              {aliveWorkers.map(_.coresUsed).sum} Used
+              Memory in use:
+              {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total,
+              {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used
diff --git a/docs/building-spark.md b/docs/building-spark.md
index b2649d1ee2a5..2128fdffecc0 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -7,11 +7,7 @@ redirect_from: "building-with-maven.html"
* This will become a table of contents (this text will be scraped).
{:toc}
-Building Spark using Maven requires Maven 3.0.4 or newer and Java 6+.
-
-**Note:** Building Spark with Java 7 or later can create JAR files that may not be
-readable with early versions of Java 6, due to the large number of files in the JAR
-archive. Build with Java 6 if this is an issue for your deployment.
+Building Spark using Maven requires Maven 3.0.4 or newer and Java 7+.
# Building with `build/mvn`
@@ -80,6 +76,7 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro
   2.2.x                  hadoop-2.2
   2.3.x                  hadoop-2.3
   2.4.x                  hadoop-2.4
+  2.6.x and later 2.x    hadoop-2.6
@@ -130,9 +127,7 @@ To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` prop
dev/change-version-to-2.11.sh
mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package
-Scala 2.11 support in Spark does not support a few features due to dependencies
-which are themselves not Scala 2.11 ready. Specifically, Spark's external
-Kafka library and JDBC component are not yet supported in Scala 2.11 builds.
+Spark does not yet support its JDBC component for Scala 2.11.
# Spark Tests in Maven
diff --git a/docs/configuration.md b/docs/configuration.md
index 30508a617fdd..affcd21514d8 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1,4 +1,4 @@
---
+---
layout: global
displayTitle: Spark Configuration
title: Configuration
@@ -334,7 +334,7 @@ Apart from these, the following properties are also available, and may be useful
Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`,
or it will be displayed before the driver exiting. It also can be dumped into disk by
- `sc.dump_profiles(path)`. If some of the profile results had been displayed maually,
+ `sc.dump_profiles(path)`. If some of the profile results had been displayed manually,
they will not be displayed automatically before driver exiting.
By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by
@@ -618,7 +618,7 @@ Apart from these, the following properties are also available, and may be useful
spark.kryo.referenceTracking
-  true
+  true (false when using Spark SQL Thrift Server)
Whether to track references to the same object when serializing data with Kryo, which is
necessary if your object graphs have loops and useful for efficiency if they contain multiple
@@ -679,7 +679,10 @@ Apart from these, the following properties are also available, and may be useful
spark.serializer
-  org.apache.spark.serializer.JavaSerializer
+  org.apache.spark.serializer.JavaSerializer (org.apache.spark.serializer.KryoSerializer
+  when using Spark SQL Thrift Server)
Class to use for serializing objects that will be sent over the network or need to be cached
in serialized form. The default of Java serialization works with any Serializable Java object
@@ -1201,6 +1204,15 @@ Apart from these, the following properties are also available, and may be useful
description.
+
+  spark.dynamicAllocation.cachedExecutorIdleTimeout
+    2 * executorIdleTimeout
+    If dynamic allocation is enabled and an executor which has cached data blocks has been idle
+    for more than this duration, the executor will be removed. For more details, see this
+    description.
+
spark.dynamicAllocation.initialExecutors
spark.dynamicAllocation.minExecutors
@@ -1483,6 +1495,18 @@ Apart from these, the following properties are also available, and may be useful
+#### SparkR
+
+  Property Name                 Default  Meaning
+
+  spark.r.numRBackendThreads    2
+    Number of threads used by RBackend to handle RPC calls from SparkR package.
+
#### Cluster Managers
Each cluster manager in Spark has additional configuration options. Configurations
can be found on the pages for each mode:
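
Both properties documented above can also be set programmatically before the context is created; a hedged snippet with example values only (dynamic allocation additionally requires the external shuffle service, as described elsewhere on this page):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.dynamicAllocation.cachedExecutorIdleTimeout", "10min") // example value
  .set("spark.r.numRBackendThreads", "4")                            // example value
```
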
diff --git a/docs/hadoop-provided.md b/docs/hadoop-provided.md
new file mode 100644
index 000000000000..bbd26b343e2e
--- /dev/null
+++ b/docs/hadoop-provided.md
@@ -0,0 +1,26 @@
+---
+layout: global
+displayTitle: Using Spark's "Hadoop Free" Build
+title: Using Spark's "Hadoop Free" Build
+---
+
+Spark uses Hadoop client libraries for HDFS and YARN. Starting in Spark 1.4, the project packages "Hadoop free" builds that let you more easily connect a single Spark binary to any Hadoop version. To use these builds, you need to modify `SPARK_DIST_CLASSPATH` to include Hadoop's package jars. The most convenient place to do this is by adding an entry in `conf/spark-env.sh`.
+
+This page describes how to connect Spark to Hadoop for different types of distributions.
+
+# Apache Hadoop
+For Apache distributions, you can use Hadoop's 'classpath' command. For instance:
+
+{% highlight bash %}
+### in conf/spark-env.sh ###
+
+# If 'hadoop' binary is on your PATH
+export SPARK_DIST_CLASSPATH=$(hadoop classpath)
+
+# With explicit path to 'hadoop' binary
+export SPARK_DIST_CLASSPATH=$(/path/to/hadoop/bin/hadoop classpath)
+
+# Passing a Hadoop configuration directory
+export SPARK_DIST_CLASSPATH=$(hadoop --config /path/to/configs classpath)
+
+{% endhighlight %}
diff --git a/docs/index.md b/docs/index.md
index 5ef6d983c45a..d85cf12defef 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -12,15 +12,19 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog
# Downloading
-Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. The downloads page
-contains Spark packages for many popular HDFS versions. If you'd like to build Spark from
-scratch, visit [Building Spark](building-spark.html).
+Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions.
+Users can also download a "Hadoop free" binary and run Spark with any Hadoop version
+[by augmenting Spark's classpath](hadoop-provided.html).
+
+If you'd like to build Spark from
+source, visit [Building Spark](building-spark.html).
+
Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy to run
locally on one machine --- all you need is to have `java` installed on your system `PATH`,
or the `JAVA_HOME` environment variable pointing to a Java installation.
-Spark runs on Java 6+, Python 2.6+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses
+Spark runs on Java 7+, Python 2.6+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses
Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version
({{site.SCALA_BINARY_VERSION}}.x).
@@ -54,7 +58,7 @@ Example applications are also provided in Python. For example,
./bin/spark-submit examples/src/main/python/pi.py 10
-Spark also provides an experimental R API since 1.4 (only DataFrames APIs included).
+Spark also provides an experimental [R API](sparkr.html) since 1.4 (only DataFrames APIs included).
To run Spark interactively in an R interpreter, use `bin/sparkR`:
./bin/sparkR --master local[2]
diff --git a/docs/ml-features.md b/docs/ml-features.md
index d7851a55fabf..f88c0248c1a8 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -456,6 +456,122 @@ for expanded in polyDF.select("polyFeatures").take(3):
+## StringIndexer
+
+`StringIndexer` encodes a string column of labels to a column of label indices.
+The indices are in `[0, numLabels)`, ordered by label frequencies.
+So the most frequent label gets index `0`.
+If the input column is numeric, we cast it to string and index the string values.
+
+**Examples**
+
+Assume that we have the following DataFrame with columns `id` and `category`:
+
+~~~~
+ id | category
+----|----------
+ 0 | a
+ 1 | b
+ 2 | c
+ 3 | a
+ 4 | a
+ 5 | c
+~~~~
+
+`category` is a string column with three labels: "a", "b", and "c".
+Applying `StringIndexer` with `category` as the input column and `categoryIndex` as the output
+column, we should get the following:
+
+~~~~
+ id | category | categoryIndex
+----|----------|---------------
+ 0 | a | 0.0
+ 1 | b | 2.0
+ 2 | c | 1.0
+ 3 | a | 0.0
+ 4 | a | 0.0
+ 5 | c | 1.0
+~~~~
+
+"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
+index `2`.
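+
+As a minimal Scala sketch of this usage (assuming an existing `sqlContext` and the `id`/`category`
+data shown above; the output matches the second table):
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.StringIndexer
+
+val df = sqlContext.createDataFrame(
+  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
+).toDF("id", "category")
+
+// Fit the indexer on the column, then map each label to its frequency-based index
+val indexer = new StringIndexer()
+  .setInputCol("category")
+  .setOutputCol("categoryIndex")
+val indexed = indexer.fit(df).transform(df)
+indexed.show()
+{% endhighlight %}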
+
+
+
## OneHotEncoder
[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features
@@ -876,5 +992,207 @@ bucketedData = bucketizer.transform(dataFrame)
+## ElementwiseProduct
+
+ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v`, and the transforming vector, `w`, to yield a result vector.
+
+`\[ \begin{pmatrix}
+v_1 \\
+\vdots \\
+v_N
+\end{pmatrix} \circ \begin{pmatrix}
+ w_1 \\
+ \vdots \\
+ w_N
+ \end{pmatrix}
+= \begin{pmatrix}
+ v_1 w_1 \\
+ \vdots \\
+ v_N w_N
+ \end{pmatrix}
+\]`
+
+[`ElementwiseProduct`](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) takes the following parameter:
+
+* `scalingVec`: the transforming vector.
+
+The example below demonstrates how to transform vectors using a transforming vector value.
+
+
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.ElementwiseProduct
+import org.apache.spark.mllib.linalg.Vectors
+
+// Create some vector data; also works for sparse vectors
+val dataFrame = sqlContext.createDataFrame(Seq(
+ ("a", Vectors.dense(1.0, 2.0, 3.0)),
+ ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector")
+
+val transformingVector = Vectors.dense(0.0, 1.0, 2.0)
+val transformer = new ElementwiseProduct()
+ .setScalingVec(transformingVector)
+ .setInputCol("vector")
+ .setOutputCol("transformedVector")
+
+// Batch transform the vectors to create new column:
+val transformedData = transformer.transform(dataFrame)
+
+{% endhighlight %}
+
+
+## VectorAssembler
+
+`VectorAssembler` is a transformer that combines a given list of columns into a single vector
+column.
+It is useful for combining raw features and features generated by different feature transformers
+into a single feature vector, in order to train ML models like logistic regression and decision
+trees.
+`VectorAssembler` accepts the following input column types: all numeric types, boolean type,
+and vector type.
+In each row, the values of the input columns will be concatenated into a vector in the specified
+order.
+
+**Examples**
+
+Assume that we have a DataFrame with the columns `id`, `hour`, `mobile`, `userFeatures`,
+and `clicked`:
+
+~~~
+ id | hour | mobile | userFeatures | clicked
+----|------|--------|------------------|---------
+ 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0
+~~~
+
+`userFeatures` is a vector column that contains three user features.
+We want to combine `hour`, `mobile`, and `userFeatures` into a single feature vector
+called `features` and use it to predict `clicked` or not.
+If we set `VectorAssembler`'s input columns to `hour`, `mobile`, and `userFeatures` and
+output column to `features`, after transformation we should get the following DataFrame:
+
+~~~
+ id | hour | mobile | userFeatures | clicked | features
+----|------|--------|------------------|---------|-----------------------------
+ 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0 | [18.0, 1.0, 0.0, 10.0, 0.5]
+~~~
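+
+A rough Scala sketch of this transformation (assuming an existing `sqlContext`; the single-row
+DataFrame mirrors the table above):
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.mllib.linalg.Vectors
+
+val dataset = sqlContext.createDataFrame(
+  Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0))
+).toDF("id", "hour", "mobile", "userFeatures", "clicked")
+
+// Concatenate the selected columns, in order, into a single vector column
+val assembler = new VectorAssembler()
+  .setInputCols(Array("hour", "mobile", "userFeatures"))
+  .setOutputCol("features")
+val output = assembler.transform(dataset)
+output.select("features", "clicked").show()
+{% endhighlight %}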
+
+
+
# Feature Selectors
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index c5f50ed7990f..4eb622d4b95e 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -207,7 +207,7 @@ val model1 = lr.fit(training.toDF)
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
-println("Model 1 was fit using parameters: " + model1.fittingParamMap)
+println("Model 1 was fit using parameters: " + model1.parent.extractParamMap)
// We may alternatively specify parameters using a ParamMap,
// which supports several methods for specifying parameters.
@@ -222,7 +222,7 @@ val paramMapCombined = paramMap ++ paramMap2
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
val model2 = lr.fit(training.toDF, paramMapCombined)
-println("Model 2 was fit using parameters: " + model2.fittingParamMap)
+println("Model 2 was fit using parameters: " + model2.parent.extractParamMap)
// Prepare test data.
val test = sc.parallelize(Seq(
@@ -289,7 +289,7 @@ LogisticRegressionModel model1 = lr.fit(training);
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
-System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap());
+System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());
// We may alternatively specify parameters using a ParamMap.
ParamMap paramMap = new ParamMap();
@@ -305,7 +305,7 @@ ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
-System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap());
+System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
// Prepare test documents.
List localTest = Lists.newArrayList(
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index f41ca70952eb..1b088969ddc2 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -47,7 +47,7 @@ Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasin
optimal *k* is usually one where there is an "elbow" in the WSSSE graph.
{% highlight scala %}
-import org.apache.spark.mllib.clustering.KMeans
+import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
@@ -62,6 +62,10 @@ val clusters = KMeans.train(parsedData, numClusters, numIterations)
// Evaluate clustering by computing Within Set Sum of Squared Errors
val WSSSE = clusters.computeCost(parsedData)
println("Within Set Sum of Squared Errors = " + WSSSE)
+
+// Save and load model
+clusters.save(sc, "myModelPath")
+val sameModel = KMeansModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -110,6 +114,10 @@ public class KMeansExample {
// Evaluate clustering by computing Within Set Sum of Squared Errors
double WSSSE = clusters.computeCost(parsedData.rdd());
System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
+
+ // Save and load model
+ clusters.save(sc.sc(), "myModelPath");
+ KMeansModel sameModel = KMeansModel.load(sc.sc(), "myModelPath");
}
}
{% endhighlight %}
@@ -124,7 +132,7 @@ Within Set Sum of Squared Error (WSSSE). You can reduce this error measure by in
fact the optimal *k* is usually one where there is an "elbow" in the WSSSE graph.
{% highlight python %}
-from pyspark.mllib.clustering import KMeans
+from pyspark.mllib.clustering import KMeans, KMeansModel
from numpy import array
from math import sqrt
@@ -143,6 +151,10 @@ def error(point):
WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y)
print("Within Set Sum of Squared Error = " + str(WSSSE))
+
+# Save and load model
+clusters.save(sc, "myModelPath")
+sameModel = KMeansModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -237,11 +249,11 @@ public class GaussianMixtureExample {
GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());
// Save and load GaussianMixtureModel
- gmm.save(sc, "myGMMModel")
- GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel")
+ gmm.save(sc.sc(), "myGMMModel");
+ GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel");
// Output the parameters of the mixture model
for(int j=0; j
println(s"${a.id} -> ${a.cluster}")
}
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = PowerIterationClusteringModel.load(sc, "myModelPath")
{% endhighlight %}
A full example that produces the experiment described in the PIC paper can be found under
@@ -360,6 +376,10 @@ PowerIterationClusteringModel model = pic.run(similarities);
for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) {
System.out.println(a.id() + " -> " + a.cluster());
}
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index f723cd6b9dfa..4fe470a8de81 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -188,7 +188,7 @@ Here we assume the extracted file is `text8` and in same directory as you run th
import org.apache.spark._
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.feature.Word2Vec
+import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
val input = sc.textFile("text8").map(line => line.split(" ").toSeq)
@@ -201,6 +201,10 @@ val synonyms = model.findSynonyms("china", 40)
for((synonym, cosineSimilarity) <- synonyms) {
println(s"$synonym $cosineSimilarity")
}
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = Word2VecModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -410,6 +414,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.feature.ChiSqSelector
// Load some data in libsvm format
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
@@ -505,7 +510,7 @@ v_N
### Example
-This example below demonstrates how to load a simple vectors file, extract a set of vectors, then transform those vectors using a transforming vector value.
+The example below demonstrates how to transform vectors using a transforming vector value.
@@ -514,16 +519,44 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.feature.ElementwiseProduct
import org.apache.spark.mllib.linalg.Vectors
-// Load and parse the data:
-val data = sc.textFile("data/mllib/kmeans_data.txt")
-val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))
+// Create some vector data; also works for sparse vectors
+val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0)))
val transformingVector = Vectors.dense(0.0, 1.0, 2.0)
val transformer = new ElementwiseProduct(transformingVector)
// Batch transform and per-row transform give the same results:
-val transformedData = transformer.transform(parsedData)
-val transformedData2 = parsedData.map(x => transformer.transform(x))
+val transformedData = transformer.transform(data)
+val transformedData2 = data.map(x => transformer.transform(x))
+
+{% endhighlight %}
+
+
+
+{% highlight java %}
+import java.util.Arrays;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.feature.ElementwiseProduct;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+// Create some vector data; also works for sparse vectors
+JavaRDD<Vector> data = sc.parallelize(Arrays.asList(
+  Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0)));
+Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0);
+final ElementwiseProduct transformer = new ElementwiseProduct(transformingVector);
+
+// Batch transform and per-row transform give the same results:
+JavaRDD<Vector> transformedData = transformer.transform(data);
+JavaRDD<Vector> transformedData2 = data.map(
+  new Function<Vector, Vector>() {
+    @Override
+    public Vector call(Vector v) {
+      return transformer.transform(v);
+    }
+  }
+);
{% endhighlight %}
diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md
index 9fd9be0dd01b..bcc066a18552 100644
--- a/docs/mllib-frequent-pattern-mining.md
+++ b/docs/mllib-frequent-pattern-mining.md
@@ -39,11 +39,11 @@ MLlib's FP-growth implementation takes the following (hyper-)parameters:
-[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the
+[`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the
FP-growth algorithm.
It takes a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type.
Calling `FPGrowth.run` with transactions returns an
-[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html)
+[`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel)
that stores the frequent itemsets with their frequencies.
{% highlight scala %}
diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md
index b521c2f27cd6..5732bc4c7e79 100644
--- a/docs/mllib-isotonic-regression.md
+++ b/docs/mllib-isotonic-regression.md
@@ -60,7 +60,7 @@ Model is created using the training set and a mean squared error is calculated f
labels and real labels in the test set.
{% highlight scala %}
-import org.apache.spark.mllib.regression.IsotonicRegression
+import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel}
val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt")
@@ -88,6 +88,10 @@ val predictionAndLabel = test.map { point =>
// Calculate mean squared error between predicted and real labels.
val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean()
println("Mean Squared Error = " + meanSquaredError)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = IsotonicRegressionModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -150,6 +154,10 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map(
).rdd()).mean();
System.out.println("Mean Squared Error = " + meanSquaredError);
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index 8029edca1600..3dc8cc902fa7 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -163,11 +163,8 @@ object, and make predictions with the resulting model to compute the training
error.
{% highlight scala %}
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
// Load training data in LIBSVM format.
@@ -231,15 +228,13 @@ calling `.rdd()` on your `JavaRDD` object. A self-contained application example
that is equivalent to the provided example in Scala is given below:
{% highlight java %}
-import java.util.Random;
-
import scala.Tuple2;
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.*;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
-import org.apache.spark.mllib.linalg.Vector;
+
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;
@@ -282,8 +277,8 @@ public class SVMClassifier {
System.out.println("Area under ROC = " + auROC);
// Save and load model
- model.save(sc.sc(), "myModelPath");
- SVMModel sameModel = SVMModel.load(sc.sc(), "myModelPath");
+ model.save(sc, "myModelPath");
+ SVMModel sameModel = SVMModel.load(sc, "myModelPath");
}
}
{% endhighlight %}
@@ -315,15 +310,12 @@ a dependency.
-The following example shows how to load a sample dataset, build Logistic Regression model,
+The following example shows how to load a sample dataset, build an SVM model,
and make predictions with the resulting model to compute the training error.
-Note that the Python API does not yet support model save/load but will in the future.
-
{% highlight python %}
-from pyspark.mllib.classification import LogisticRegressionWithSGD
+from pyspark.mllib.classification import SVMWithSGD, SVMModel
from pyspark.mllib.regression import LabeledPoint
-from numpy import array
# Load and parse the data
def parsePoint(line):
@@ -334,12 +326,16 @@ data = sc.textFile("data/mllib/sample_svm_data.txt")
parsedData = data.map(parsePoint)
# Build the model
-model = LogisticRegressionWithSGD.train(parsedData)
+model = SVMWithSGD.train(parsedData, iterations=100)
# Evaluating the model on training data
labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features)))
trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count())
print("Training Error = " + str(trainErr))
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = SVMModel.load(sc, "myModelPath")
{% endhighlight %}
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index acdcc371487f..bf6d124fd5d8 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -53,7 +53,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
val test = splits(1)
-val model = NaiveBayes.train(training, lambda = 1.0, model = "multinomial")
+val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial")
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
diff --git a/docs/monitoring.md b/docs/monitoring.md
index e75018499003..bcf885fe4e68 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -228,6 +228,14 @@ for a running application, at `http://localhost:4040/api/v1`.
   <tr>
     <td><code>/applications/[app-id]/storage/rdd/[rdd-id]</code></td>
     <td>Details for the storage status of a given RDD</td>
   </tr>
+  <tr>
+    <td><code>/applications/[app-id]/logs</code></td>
+    <td>Download the event logs for all attempts of the given application as a zip file</td>
+  </tr>
+  <tr>
+    <td><code>/applications/[app-id]/[attempt-id]/logs</code></td>
+    <td>Download the event logs for the specified attempt of the given application as a zip file</td>
+  </tr>
When running on Yarn, each application has multiple attempts, so `[app-id]` is actually
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 10f474f237bf..d5ff416fe89a 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -54,7 +54,7 @@ import org.apache.spark.SparkConf
-Spark {{site.SPARK_VERSION}} works with Java 6 and higher. If you are using Java 8, Spark supports
+Spark {{site.SPARK_VERSION}} works with Java 7 and higher. If you are using Java 8, Spark supports
[lambda expressions](http://docs.oracle.com/javase/tutorial/java/javaOO/lambdaexpressions.html)
for concisely writing functions, otherwise you can use the classes in the
[org.apache.spark.api.java.function](api/java/index.html?org/apache/spark/api/java/function/package-summary.html) package.
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 9d55f435e80a..96cf612c54fd 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -242,6 +242,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
running against earlier versions, this property will be ignored.
+<tr>
+  <td><code>spark.yarn.keytab</code></td>
+  <td>(none)</td>
+  <td>
+    The full path to the file that contains the keytab for the principal specified above.
+    This keytab will be copied to the node running the Application Master via the Secure Distributed Cache,
+    for renewing the login tickets and the delegation tokens periodically.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.yarn.principal</code></td>
+  <td>(none)</td>
+  <td>
+    Principal to be used to login to KDC, while running on secure HDFS.
+  </td>
+</tr>
# Launching Spark on YARN
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 0eed9adacf12..4f71fbc086cd 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -24,7 +24,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by
Similarly, you can start one or more workers and connect them to the master via:
- ./sbin/start-slave.sh
+ ./sbin/start-slave.sh
Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default).
You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS).
@@ -77,7 +77,7 @@ Note, the master machine accesses each of the worker machines via ssh. By defaul
If you do not have a password-less setup, you can set the environment variable SPARK_SSH_FOREGROUND and serially provide a password for each worker.
-Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`:
+Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/sbin`:
- `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on.
- `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file.
diff --git a/docs/sparkr.md b/docs/sparkr.md
new file mode 100644
index 000000000000..4d82129921a3
--- /dev/null
+++ b/docs/sparkr.md
@@ -0,0 +1,223 @@
+---
+layout: global
+displayTitle: SparkR (R on Spark)
+title: SparkR (R on Spark)
+---
+
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
+# Overview
+SparkR is an R package that provides a light-weight frontend to use Apache Spark from R.
+In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that
+supports operations like selection, filtering, aggregation etc. (similar to R data frames,
+[dplyr](https://github.com/hadley/dplyr)) but on large datasets.
+
+# SparkR DataFrames
+
+A DataFrame is a distributed collection of data organized into named columns. It is conceptually
+equivalent to a table in a relational database or a data frame in R, but with richer
+optimizations under the hood. DataFrames can be constructed from a wide array of sources such as:
+structured data files, tables in Hive, external databases, or existing local R data frames.
+
+All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell.
+
+## Starting Up: SparkContext, SQLContext
+
+
+The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster.
+You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name
+etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the
+SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should
+already be created for you.
+
+{% highlight r %}
+sc <- sparkR.init()
+sqlContext <- sparkRSQL.init(sc)
+{% endhighlight %}
+
+
+
+## Creating DataFrames
+With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources).
+
+### From local data frames
+The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based on the `faithful` dataset from R.
+
+
+{% highlight r %}
+df <- createDataFrame(sqlContext, faithful)
+
+# Displays the content of the DataFrame to stdout
+head(df)
+## eruptions waiting
+##1 3.600 79
+##2 1.800 54
+##3 3.333 74
+
+{% endhighlight %}
+
+
+### From Data Sources
+
+SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources.
+
+The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro).
+
+We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail.
+
+
+
+{% highlight r %}
+people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json")
+head(people)
+## age name
+##1 NA Michael
+##2 30 Andy
+##3 19 Justin
+
+# SparkR automatically infers the schema from the JSON file
+printSchema(people)
+# root
+# |-- age: integer (nullable = true)
+# |-- name: string (nullable = true)
+
+{% endhighlight %}
+
+
+The data sources API can also be used to save out DataFrames into multiple file formats. For example, we can save the DataFrame from the previous example
+to a Parquet file using `write.df`.
+
+
+
+### From Hive tables
+
+You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext).
+
+
+{% highlight r %}
+# sc is an existing SparkContext.
+hiveContext <- sparkRHive.init(sc)
+
+sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
+
+# Queries can be expressed in HiveQL.
+results <- sql(hiveContext, "FROM src SELECT key, value")
+
+# results is now a DataFrame
+head(results)
+## key value
+## 1 238 val_238
+## 2 86 val_86
+## 3 311 val_311
+
+{% endhighlight %}
+
+
+## DataFrame Operations
+
+SparkR DataFrames support a number of functions to do structured data processing.
+Here we include some basic examples and a complete list can be found in the [API](api/R/index.html) docs:
+
+### Selecting rows, columns
+
+
+{% highlight r %}
+# Create the DataFrame
+df <- createDataFrame(sqlContext, faithful)
+
+# Get basic information about the DataFrame
+df
+## DataFrame[eruptions:double, waiting:double]
+
+# Select only the "eruptions" column
+head(select(df, df$eruptions))
+## eruptions
+##1 3.600
+##2 1.800
+##3 3.333
+
+# You can also pass in column name as strings
+head(select(df, "eruptions"))
+
+# Filter the DataFrame to only retain rows with wait times shorter than 50 mins
+head(filter(df, df$waiting < 50))
+## eruptions waiting
+##1 1.750 47
+##2 1.750 47
+##3 1.867 48
+
+{% endhighlight %}
+
+
+
+### Grouping, Aggregation
+
+SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example, we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below.
+
+
+{% highlight r %}
+
+# We use the `n` operator to count the number of times each waiting time appears
+head(summarize(groupBy(df, df$waiting), count = n(df$waiting)))
+## waiting count
+##1 81 13
+##2 60 6
+##3 68 1
+
+# We can also sort the output from the aggregation to get the most common waiting times
+waiting_counts <- summarize(groupBy(df, df$waiting), count = n(df$waiting))
+head(arrange(waiting_counts, desc(waiting_counts$count)))
+
+## waiting count
+##1 78 15
+##2 83 14
+##3 81 13
+
+{% endhighlight %}
+
+
+### Operating on Columns
+
+SparkR also provides a number of functions that can be directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions.
+
+
+{% highlight r %}
+
+# Convert waiting time from minutes to seconds.
+# Note that we can assign this to a new column in the same DataFrame
+df$waiting_secs <- df$waiting * 60
+head(df)
+## eruptions waiting waiting_secs
+##1 3.600 79 4740
+##2 1.800 54 3240
+##3 3.333 74 4440
+
+{% endhighlight %}
+
+
+## Running SQL Queries from SparkR
+A SparkR DataFrame can also be registered as a temporary table in Spark SQL; registering a DataFrame as a table allows you to run SQL queries over its data.
+The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`.
+
+
+{% highlight r %}
+# Load a JSON file
+people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json")
+
+# Register this DataFrame as a table.
+registerTempTable(people, "people")
+
+# SQL statements can be run by using the sql method
+teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19")
+head(teenagers)
+## name
+##1 Justin
+
+{% endhighlight %}
+
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index ab646f65bb5e..61f9c5f02ac7 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -11,6 +11,7 @@ title: Spark SQL and DataFrames
Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine.
+For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section.
# DataFrames
@@ -108,7 +109,7 @@ As an example, the following creates a `DataFrame` based on the content of a JSO
val sc: SparkContext // An existing SparkContext.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
-val df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+val df = sqlContext.read.json("examples/src/main/resources/people.json")
// Displays the content of the DataFrame to stdout
df.show()
@@ -121,7 +122,7 @@ df.show()
JavaSparkContext sc = ...; // An existing JavaSparkContext.
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
-DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json");
+DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json");
// Displays the content of the DataFrame to stdout
df.show();
@@ -134,7 +135,7 @@ df.show();
from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
-df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+df = sqlContext.read.json("examples/src/main/resources/people.json")
# Displays the content of the DataFrame to stdout
df.show()
@@ -170,7 +171,7 @@ val sc: SparkContext // An existing SparkContext.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
// Create the DataFrame
-val df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+val df = sqlContext.read.json("examples/src/main/resources/people.json")
// Show the content of the DataFrame
df.show()
@@ -220,7 +221,7 @@ JavaSparkContext sc // An existing SparkContext.
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc)
// Create the DataFrame
-DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json");
+DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json");
// Show the content of the DataFrame
df.show();
@@ -276,7 +277,7 @@ from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
# Create the DataFrame
-df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+df = sqlContext.read.json("examples/src/main/resources/people.json")
# Show the content of the DataFrame
df.show()
@@ -776,8 +777,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
Ignore mode means that when saving a DataFrame to a data source, if data already exists,
the save operation is expected to not save the contents of the DataFrame and to not
- change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL.
+ change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL.
@@ -946,11 +947,11 @@ import sqlContext.implicits._
val people: RDD[Person] = ... // An RDD of case class objects, from the previous example.
// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet.
-people.saveAsParquetFile("people.parquet")
+people.write.parquet("people.parquet")
// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved.
// The result of loading a Parquet file is also a DataFrame.
-val parquetFile = sqlContext.parquetFile("people.parquet")
+val parquetFile = sqlContext.read.parquet("people.parquet")
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile")
@@ -968,11 +969,11 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
DataFrame schemaPeople = ... // The DataFrame from the previous example.
// DataFrames can be saved as Parquet files, maintaining the schema information.
-schemaPeople.saveAsParquetFile("people.parquet");
+schemaPeople.write().parquet("people.parquet");
// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
// The result of loading a parquet file is also a DataFrame.
-DataFrame parquetFile = sqlContext.parquetFile("people.parquet");
+DataFrame parquetFile = sqlContext.read().parquet("people.parquet");
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
@@ -994,11 +995,11 @@ List teenagerNames = teenagers.javaRDD().map(new Function()
schemaPeople # The DataFrame from the previous example.
# DataFrames can be saved as Parquet files, maintaining the schema information.
-schemaPeople.saveAsParquetFile("people.parquet")
+schemaPeople.write.parquet("people.parquet")
# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
# The result of loading a parquet file is also a DataFrame.
-parquetFile = sqlContext.parquetFile("people.parquet")
+parquetFile = sqlContext.read.parquet("people.parquet")
# Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
@@ -1030,7 +1031,7 @@ teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND a
teenNames <- map(teenagers, function(p) { paste("Name:", p$name)})
for (teenName in collect(teenNames)) {
cat(teenName, "\n")
-}
+}
{% endhighlight %}
@@ -1086,9 +1087,9 @@ path
{% endhighlight %}
-By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will
-automatically extract the partitioning information from the paths. Now the schema of the returned
-DataFrame becomes:
+By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL
+will automatically extract the partitioning information from the paths.
+Now the schema of the returned DataFrame becomes:
{% highlight text %}
@@ -1101,7 +1102,11 @@ root
{% endhighlight %}
Notice that the data types of the partitioning columns are automatically inferred. Currently,
-numeric data types and string type are supported.
+numeric data types and string type are supported. Sometimes users may not want to automatically
+infer the data types of the partitioning columns. For these use cases, the automatic type inference
+can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which defaults to
+`true`. When type inference is disabled, string type will be used for the partitioning columns.
+
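+For example, a minimal sketch of turning the inference off on an existing `sqlContext` (the same
+setting can also be supplied via `--conf` or `spark-defaults.conf`):
+
+{% highlight scala %}
+// Partition columns such as key=1 will now be read as strings rather than integers
+sqlContext.setConf("spark.sql.sources.partitionColumnTypeInference.enabled", "false")
+{% endhighlight %}
+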
### Schema merging
@@ -1121,15 +1126,15 @@ import sqlContext.implicits._
// Create a simple DataFrame, stored into a partition directory
val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double")
-df1.saveAsParquetFile("data/test_table/key=1")
+df1.write.parquet("data/test_table/key=1")
// Create another DataFrame in a new partition directory,
// adding a new column and dropping an existing column
val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple")
-df2.saveAsParquetFile("data/test_table/key=2")
+df2.write.parquet("data/test_table/key=2")
// Read the partitioned table
-val df3 = sqlContext.parquetFile("data/test_table")
+val df3 = sqlContext.read.parquet("data/test_table")
df3.printSchema()
// The final schema consists of all 3 columns in the Parquet files together
@@ -1268,12 +1273,10 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext`:
-
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
-* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object.
+This conversion can be done using `SQLContext.read.json()` on either an RDD of String,
+or a JSON file.
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a json file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1284,8 +1287,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc)
// A JSON dataset is pointed to by path.
// The path can be either a single text file or a directory storing text files.
val path = "examples/src/main/resources/people.json"
-// Create a DataFrame from the file(s) pointed to by path
-val people = sqlContext.jsonFile(path)
+val people = sqlContext.read.json(path)
// The inferred schema can be visualized using the printSchema() method.
people.printSchema()
@@ -1303,19 +1305,17 @@ val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age
// an RDD[String] storing one JSON object per string.
val anotherPeopleRDD = sc.parallelize(
"""{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil)
-val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
+val anotherPeople = sqlContext.read.json(anotherPeopleRDD)
{% endhighlight %}
Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext` :
+This conversion can be done using `SQLContext.read().json()` on either an RDD of String,
+or a JSON file.
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
-* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object.
-
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a json file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1325,9 +1325,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
// A JSON dataset is pointed to by path.
// The path can be either a single text file or a directory storing text files.
-String path = "examples/src/main/resources/people.json";
-// Create a DataFrame from the file(s) pointed to by path
-DataFrame people = sqlContext.jsonFile(path);
+DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json");
// The inferred schema can be visualized using the printSchema() method.
people.printSchema();
@@ -1346,18 +1344,15 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN
List jsonData = Arrays.asList(
"{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}");
JavaRDD anotherPeopleRDD = sc.parallelize(jsonData);
-DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD);
+DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD);
{% endhighlight %}
Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext`:
-
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
-* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object.
+This conversion can be done using `SQLContext.read.json` on a JSON file.
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a json file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1368,9 +1363,7 @@ sqlContext = SQLContext(sc)
# A JSON dataset is pointed to by path.
# The path can be either a single text file or a directory storing text files.
-path = "examples/src/main/resources/people.json"
-# Create a DataFrame from the file(s) pointed to by path
-people = sqlContext.jsonFile(path)
+people = sqlContext.read.json("examples/src/main/resources/people.json")
# The inferred schema can be visualized using the printSchema() method.
people.printSchema()
@@ -1393,12 +1386,11 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext`:
-
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame using
+the `jsonFile` function, which loads data from a directory of JSON files where each line of the
+files is a JSON object.
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a json file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1487,7 +1479,7 @@ expressed in HiveQL.
{% highlight java %}
// sc is an existing JavaSparkContext.
-HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc);
+HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc.sc());
sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)");
sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src");
@@ -1502,7 +1494,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and
-adds support for finding tables in the MetaStore and writing queries using HiveQL.
+adds support for finding tables in the MetaStore and writing queries using HiveQL.
{% highlight python %}
# sc is an existing SparkContext.
from pyspark.sql import HiveContext
@@ -1526,8 +1518,8 @@ adds support for finding tables in the MetaStore and writing queries using HiveQ
# sc is an existing SparkContext.
sqlContext <- sparkRHive.init(sc)
-hql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
-hql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
+sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
# Queries can be expressed in HiveQL.
results = sqlContext.sql("FROM src SELECT key, value").collect()
@@ -1537,6 +1529,70 @@ results = sqlContext.sql("FROM src SELECT key, value").collect()
+### Interacting with Different Versions of Hive Metastore
+
+One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore,
+which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below.
+
+Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET`
+and `DESCRIBE`, and the other dedicated to communicating with the Hive metastore. The former uses Hive
+jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the
+version specified by users. An isolated classloader is used here to avoid dependency conflicts.
+
+<table class="table">
+  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.version</code></td>
+    <td><code>0.13.1</code></td>
+    <td>
+      Version of the Hive metastore. Available
+      options are <code>0.12.0</code> and <code>0.13.1</code>. Support for more versions is coming in the future.
+    </td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.jars</code></td>
+    <td><code>builtin</code></td>
+    <td>
+      Location of the jars that should be used to instantiate the HiveMetastoreClient. This
+      property can be one of three options:
+      <ol>
+        <li><code>builtin</code>: Use Hive 0.13.1, which is bundled with the Spark assembly jar when
+          <code>-Phive</code> is enabled. When this option is chosen,
+          <code>spark.sql.hive.metastore.version</code> must be either <code>0.13.1</code> or not defined.</li>
+        <li><code>maven</code>: Use Hive jars of specified version downloaded from Maven repositories.</li>
+        <li>A classpath in the standard format for both Hive and Hadoop.</li>
+      </ol>
+    </td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.sharedPrefixes</code></td>
+    <td><code>com.mysql.jdbc,<br/>org.postgresql,<br/>com.microsoft.sqlserver,<br/>oracle.jdbc</code></td>
+    <td>
+      A comma separated list of class prefixes that should be loaded using the classloader that is
+      shared between Spark SQL and a specific version of Hive. An example of classes that should
+      be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need
+      to be shared are those that interact with classes that are already shared. For example,
+      custom appenders that are used by log4j.
+    </td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.barrierPrefixes</code></td>
+    <td><code>(empty)</code></td>
+    <td>
+      A comma separated list of class prefixes that should explicitly be reloaded for each version
+      of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a
+      prefix that typically would be shared (i.e. <code>org.apache.spark.*</code>).
+    </td>
+  </tr>
+</table>
+
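+As a rough sketch, these properties can be set on the application's `SparkConf` before the
+`HiveContext` creates its metastore client (the version and jars values below are illustrative and
+depend on your Hive deployment):
+
+{% highlight scala %}
+import org.apache.spark.{SparkConf, SparkContext}
+
+val conf = new SparkConf()
+  .setAppName("HiveMetastoreVersionExample")
+  // Talk to a Hive 0.12.0 metastore using jars resolved from Maven
+  .set("spark.sql.hive.metastore.version", "0.12.0")
+  .set("spark.sql.hive.metastore.jars", "maven")
+val sc = new SparkContext(conf)
+val sqlContext = new org.apache.spark.sql.hive.HiveContext(sc)
+{% endhighlight %}
+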
## JDBC To Other Databases
Spark SQL also includes a data source that can read data from other databases using JDBC. This
@@ -1570,7 +1626,7 @@ the Data Sources API. The following options are supported:
dbtable
- The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of
+ The JDBC table that should be read. Note that anything that is valid in a FROM clause of
a SQL query can be used. For example, instead of a full table you could also use a
subquery in parentheses.
@@ -1714,7 +1770,7 @@ that these options will be deprecated in future release as more optimizations ar
Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when
performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently
statistics are only supported for Hive Metastore tables where the command
- `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run.
+ ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.
@@ -1733,11 +1789,20 @@ that these options will be deprecated in future release as more optimizations ar
Configures the number of partitions to use when shuffling data for joins or aggregations.
+<tr>
+  <td><code>spark.sql.planner.externalSort</code></td>
+  <td>false</td>
+  <td>
+    When true, performs sorts spilling to disk as needed; otherwise sorts each partition in memory.
+  </td>
+</tr>
+
# Distributed SQL Engine
-Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code.
+Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface.
+In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries,
+without the need to write any code.
## Running the Thrift JDBC/ODBC server
@@ -1751,7 +1816,7 @@ To start the JDBC/ODBC server, run the following in the Spark directory:
This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to
specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of
all available options. By default, the server listens on localhost:10000. You may override this
-bahaviour via either environment variables, i.e.:
+behaviour via either environment variables, i.e.:
{% highlight bash %}
export HIVE_SERVER2_THRIFT_PORT=
@@ -1816,6 +1881,25 @@ options.
## Upgrading from Spark SQL 1.3 to 1.4
+#### DataFrame data reader/writer interface
+
+Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`)
+and writing data out (`DataFrame.write`),
+and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`).
+
+See the API docs for `SQLContext.read` (Scala, Java, Python) and `DataFrame.write` (Scala, Java, Python) for more information.
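+
+For example, a minimal sketch of moving Parquet code to the new interface (paths are placeholders):
+
+{% highlight scala %}
+// Spark 1.3:
+//   val df = sqlContext.parquetFile("people.parquet")
+//   df.saveAsParquetFile("people_copy.parquet")
+
+// Spark 1.4:
+val df = sqlContext.read.parquet("people.parquet")
+df.write.parquet("people_copy.parquet")
+{% endhighlight %}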
+
+
+#### DataFrame.groupBy retains grouping columns
+
Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md
index 6a2048121f8b..a75587a92adc 100644
--- a/docs/streaming-custom-receivers.md
+++ b/docs/streaming-custom-receivers.md
@@ -4,7 +4,7 @@ title: Spark Streaming Custom Receivers
---
Spark Streaming can receive streaming data from any arbitrary data source beyond
-the one's for which it has in-built support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.).
+the ones for which it has built-in support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.).
This requires the developer to implement a *receiver* that is customized for receiving data from
the concerned data source. This guide walks through the process of implementing a custom receiver
and using it in a Spark Streaming application. Note that custom receivers can be implemented
@@ -21,15 +21,15 @@ A custom receiver must extend this abstract class by implementing two methods
- `onStop()`: Things to do to stop receiving data.
Both `onStart()` and `onStop()` must not block indefinitely. Typically, `onStart()` would start the threads
-that responsible for receiving the data and `onStop()` would ensure that the receiving by those threads
+that are responsible for receiving the data, and `onStop()` would ensure that these threads receiving the data
are stopped. The receiving threads can also use `isStopped()`, a `Receiver` method, to check whether they
should stop receiving data.
Once the data is received, that data can be stored inside Spark
by calling `store(data)`, which is a method provided by the Receiver class.
-There are number of flavours of `store()` which allow you store the received data
-record-at-a-time or as whole collection of objects / serialized bytes. Note that the flavour of
-`store()` used to implemented a receiver affects its reliability and fault-tolerance semantics.
+There are a number of flavors of `store()` which allow one to store the received data
+record-at-a-time or as whole collection of objects / serialized bytes. Note that the flavor of
+`store()` used to implement a receiver affects its reliability and fault-tolerance semantics.
This is discussed [later](#receiver-reliability) in more detail.
Any exception in the receiving threads should be caught and handled properly to avoid silent
@@ -60,7 +60,7 @@ class CustomReceiver(host: String, port: Int)
def onStop() {
// There is nothing much to do as the thread calling receive()
- // is designed to stop by itself isStopped() returns false
+ // is designed to stop by itself if isStopped() returns false
}
/** Create a socket connection and receive data until receiver is stopped */
@@ -123,7 +123,7 @@ public class JavaCustomReceiver extends Receiver {
public void onStop() {
// There is nothing much to do as the thread calling receive()
- // is designed to stop by itself isStopped() returns false
+ // is designed to stop by itself if isStopped() returns false
}
/** Create a socket connection and receive data until receiver is stopped */
@@ -167,7 +167,7 @@ public class JavaCustomReceiver extends Receiver {
The custom receiver can be used in a Spark Streaming application by using
`streamingContext.receiverStream()`. This will create
-input DStream using data received by the instance of custom receiver, as shown below
+an input DStream using data received by the instance of custom receiver, as shown below:
@@ -206,22 +206,20 @@ there are two kinds of receivers based on their reliability and fault-tolerance
and stored in Spark reliably (that is, replicated successfully). Usually,
implementing this receiver involves careful consideration of the semantics of source
acknowledgements.
-1. *Unreliable Receiver* - These are receivers for unreliable sources that do not support
- acknowledging. Even for reliable sources, one may implement an unreliable receiver that
- do not go into the complexity of acknowledging correctly.
+1. *Unreliable Receiver* - An *unreliable receiver* does *not* send acknowledgement to a source. This can be used for sources that do not support acknowledgement, or even for reliable sources when one does not want or need to go into the complexity of acknowledgement.
To implement a *reliable receiver*, you have to use `store(multiple-records)` to store data.
-This flavour of `store` is a blocking call which returns only after all the given records have
+This flavor of `store` is a blocking call which returns only after all the given records have
been stored inside Spark. If the receiver's configured storage level uses replication
(enabled by default), then this call returns after replication has completed.
Thus it ensures that the data is reliably stored, and the receiver can now acknowledge the
-source appropriately. This ensures that no data is caused when the receiver fails in the middle
+source appropriately. This ensures that no data is lost when the receiver fails in the middle
of replicating data -- the buffered data will not be acknowledged and hence will be later resent
by the source.
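A minimal sketch of that pattern, assuming a hypothetical `source` client exposing `poll()` and `ack()` (these are not part of any real API) and that the method lives inside a `Receiver[String]` subclass:

{% highlight scala %}
import scala.collection.mutable.ArrayBuffer

/** Receiving loop of a reliable receiver (sketch; `source` is hypothetical). */
private def receive(): Unit = {
  val buffer = new ArrayBuffer[String]
  while (!isStopped()) {
    buffer += source.poll()
    if (buffer.size >= 100) {   // arbitrary block size for illustration
      store(buffer)             // blocks until stored (and replicated) inside Spark
      source.ack()              // only now is it safe to acknowledge the source
      buffer.clear()
    }
  }
}
{% endhighlight %}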
An *unreliable receiver* does not have to implement any of this logic. It can simply receive
records from the source and insert them one-at-a-time using `store(single-record)`. While it does
-not get the reliability guarantees of `store(multiple-records)`, it has the following advantages.
+not get the reliability guarantees of `store(multiple-records)`, it has the following advantages:
- The system takes care of chunking that data into appropriately sized blocks (look for block
interval in the [Spark Streaming Programming Guide](streaming-programming-guide.html)).
diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md
index 64714f0b799f..02bc95d0e95f 100644
--- a/docs/streaming-kafka-integration.md
+++ b/docs/streaming-kafka-integration.md
@@ -7,7 +7,7 @@ title: Spark Streaming + Kafka Integration Guide
## Approach 1: Receiver-based Approach
This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming process the data.
-However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming. To ensure zero data loss, enable the Write Ahead Logs (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs.
+However, under the default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability)). To ensure zero data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g., HDFS), so that all the data can be recovered on failure. See the [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs.
Next, we discuss how to use this approach in your streaming application.
@@ -29,7 +29,7 @@ Next, we discuss how to use this approach in your streaming application.
[ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume])
You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*;
@@ -39,7 +39,7 @@ Next, we discuss how to use this approach in your streaming application.
[ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]);
You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
@@ -105,7 +105,7 @@ Next, we discuss how to use this approach in your streaming application.
streamingContext, [map of Kafka parameters], [set of topics to consume])
See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*;
@@ -116,8 +116,15 @@ Next, we discuss how to use this approach in your streaming application.
[map of Kafka parameters], [set of topics to consume]);
See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
+
+
+ from pyspark.streaming.kafka import KafkaUtils
+ directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers})
+
+ By default, the Python API will decode Kafka data as UTF-8 encoded strings. You can specify a custom decoding function to decode the byte arrays in Kafka records into any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils)
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py).
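For reference, the Scala form of the same direct call looks roughly like the sketch below; `ssc`, `brokers`, and `topicsSet` are placeholders you would define in your own application.

{% highlight scala %}
import kafka.serializer.StringDecoder
import org.apache.spark.streaming.kafka.KafkaUtils

// brokers: comma-separated "host:port" list; topicsSet: Set[String] of topic names
val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers)
val directKafkaStream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
  ssc, kafkaParams, topicsSet)
{% endhighlight %}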
@@ -147,10 +154,13 @@ Next, we discuss how to use this approach in your streaming application.
}
);
+
+ Not supported
+
You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application.
Another thing to note is that since this approach does not use Receivers, the standard receiver-related [configurations](configuration.html) (that is, those of the form `spark.streaming.receiver.*`) will not apply to the input DStreams created by this approach (they will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API.
-3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation.
\ No newline at end of file
+3. **Deploying:** This is the same as the first approach, for Scala, Java and Python.
diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md
index 379eb513d521..aa9749afbc86 100644
--- a/docs/streaming-kinesis-integration.md
+++ b/docs/streaming-kinesis-integration.md
@@ -32,7 +32,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
val kinesisStream = KinesisUtils.createStream(
- streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position])
+ streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL],
+ [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2)
See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$)
and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example.
@@ -44,7 +45,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
JavaReceiverInputDStream<byte[]> kinesisStream = KinesisUtils.createStream(
- streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]);
+ streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL],
+ [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2);
See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html)
and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example.
@@ -54,19 +56,23 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
- `streamingContext`: StreamingContext containing an application name used by Kinesis to tie this Kinesis application to the Kinesis stream
- - `[Kinesis stream name]`: The Kinesis stream that this streaming application receives from
- - The application name used in the streaming context becomes the Kinesis application name
+ - `[Kinesis app name]`: The application name that will be used to checkpoint the Kinesis
+ sequence numbers in a DynamoDB table.
- The application name must be unique for a given account and region.
- - The Kinesis backend automatically associates the application name to the Kinesis stream using a DynamoDB table (always in the us-east-1 region) created during Kinesis Client Library initialization.
- - Changing the application name or stream name can lead to Kinesis errors in some cases. If you see errors, you may need to manually delete the DynamoDB table.
+ - If the table exists but has incorrect checkpoint information (for a different stream, or
+ old expired sequence numbers), then there may be temporary errors.
+ - `[Kinesis stream name]`: The Kinesis stream that this streaming application will pull data from.
- `[endpoint URL]`: Valid Kinesis endpoints URL can be found [here](http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region).
+ - `[region name]`: Valid Kinesis region names can be found [here](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html).
+
- `[checkpoint interval]`: The interval (e.g., Duration(2000) = 2 seconds) at which the Kinesis Client Library saves its position in the stream. For starters, set it to the same as the batch interval of the streaming application.
- `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details).
+ In other versions of the API, you can also specify the AWS access key and secret key directly.
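Putting the parameters above together, a filled-in sketch might look like the following; `streamingContext` is assumed to be an existing StreamingContext, and the application name, stream name, endpoint, region, and checkpoint interval are placeholders.

{% highlight scala %}
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.kinesis.KinesisUtils

val kinesisStream = KinesisUtils.createStream(
  streamingContext, "MyKinesisApp", "MyKinesisStream",
  "https://kinesis.us-east-1.amazonaws.com", "us-east-1",
  InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_AND_DISK_2)
{% endhighlight %}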
3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide).
@@ -122,12 +128,12 @@ To run the example,
@@ -136,7 +142,7 @@ To run the example,
- To generate random string data to put onto the Kinesis stream, in another terminal, run the associated Kinesis data producer.
- bin/run-example streaming.KinesisWordCountProducerASL [Kinesis stream name] [endpoint URL] 1000 10
+ bin/run-example streaming.KinesisWordProducerASL [Kinesis stream name] [endpoint URL] 1000 10
This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example.
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index bd863d48d53e..1eb3b30332e4 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -11,7 +11,7 @@ description: Spark Streaming programming guide and tutorial for Spark SPARK_VERS
# Overview
Spark Streaming is an extension of the core Spark API that enables scalable, high-throughput,
fault-tolerant stream processing of live data streams. Data can be ingested from many sources
-like Kafka, Flume, Twitter, ZeroMQ, Kinesis or TCP sockets can be processed using complex
+like Kafka, Flume, Twitter, ZeroMQ, Kinesis, or TCP sockets, and can be processed using complex
algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`.
Finally, processed data can be pushed out to filesystems, databases,
and live dashboards. In fact, you can apply Spark's
@@ -52,7 +52,7 @@ different languages.
**Note:** Python API for Spark Streaming has been introduced in Spark 1.2. It has all the DStream
transformations and almost all the output operations available in Scala and Java interfaces.
-However, it has only support for basic sources like text files and text data over sockets.
+However, it only has support for basic sources like text files and text data over sockets.
APIs for additional sources, like Kafka and Flume, will be available in the future.
Further information about available features in the Python API is mentioned throughout this
document; look out for the tag
@@ -69,15 +69,15 @@ do is as follows.
-First, we import the names of the Spark Streaming classes, and some implicit
-conversions from StreamingContext into our environment, to add useful methods to
+First, we import the names of the Spark Streaming classes and some implicit
+conversions from StreamingContext into our environment in order to add useful methods to
other classes we need (like DStream). [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) is the
-main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second.
+main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and a batch interval of 1 second.
{% highlight scala %}
import org.apache.spark._
import org.apache.spark.streaming._
-import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+
+import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3
// Create a local StreamingContext with two working threads and a batch interval of 1 second.
// The master requires 2 cores to prevent a starvation scenario.
@@ -96,7 +96,7 @@ val lines = ssc.socketTextStream("localhost", 9999)
This `lines` DStream represents the stream of data that will be received from the data
server. Each record in this DStream is a line of text. Next, we want to split the lines by
-space into words.
+space characters into words.
{% highlight scala %}
// Split each line into words
@@ -109,7 +109,7 @@ each line will be split into multiple words and the stream of words is represent
`words` DStream. Next, we want to count these words.
{% highlight scala %}
-import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+
+import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3
// Count each word in each batch
val pairs = words.map(word => (word, 1))
val wordCounts = pairs.reduceByKey(_ + _)
@@ -463,7 +463,7 @@ receive it there. However, for local testing and unit tests, you can pass "local
in-process (detects the number of cores in the local system). Note that this internally creates a [SparkContext](api/scala/index.html#org.apache.spark.SparkContext) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`.
The batch interval must be set based on the latency requirements of your application
-and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size)
+and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval)
section for more details.
A `StreamingContext` object can also be created from an existing `SparkContext` object.
@@ -498,7 +498,7 @@ receive it there. However, for local testing and unit tests, you can pass "local
in-process. Note that this internally creates a [JavaSparkContext](api/java/index.html?org/apache/spark/api/java/JavaSparkContext.html) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`.
The batch interval must be set based on the latency requirements of your application
-and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size)
+and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval)
section for more details.
A `JavaStreamingContext` object can also be created from an existing `JavaSparkContext`.
@@ -531,7 +531,7 @@ receive it there. However, for local testing and unit tests, you can pass "local
in-process (detects the number of cores in the local system).
The batch interval must be set based on the latency requirements of your application
-and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size)
+and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval)
section for more details.
@@ -549,7 +549,7 @@ After a context is defined, you have to do the following.
- Once a context has been started, no new streaming computations can be set up or added to it.
- Once a context has been stopped, it cannot be restarted.
- Only one StreamingContext can be active in a JVM at the same time.
-- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set optional parameter of `stop()` called `stopSparkContext` to false.
+- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set the optional parameter of `stop()` called `stopSparkContext` to false.
- A SparkContext can be re-used to create multiple StreamingContexts, as long as the previous StreamingContext is stopped (without stopping the SparkContext) before the next StreamingContext is created.
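A minimal sketch of the last point, assuming `sc` is an existing SparkContext:

{% highlight scala %}
import org.apache.spark.streaming.{Seconds, StreamingContext}

val ssc1 = new StreamingContext(sc, Seconds(1))
// ... define and run the first streaming computation, then stop only the StreamingContext ...
ssc1.stop(stopSparkContext = false)

// The same SparkContext can now back a new StreamingContext
val ssc2 = new StreamingContext(sc, Seconds(5))
{% endhighlight %}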
***
@@ -583,7 +583,7 @@ the `flatMap` operation is applied on each RDD in the `lines` DStream to generat
These underlying RDD transformations are computed by the Spark engine. The DStream operations
-hide most of these details and provide the developer with higher-level API for convenience.
+hide most of these details and provide the developer with a higher-level API for convenience.
These operations are discussed in detail in later sections.
***
@@ -600,7 +600,7 @@ data from a source and stores it in Spark's memory for processing.
Spark Streaming provides two categories of built-in streaming sources.
- *Basic sources*: Sources directly available in the StreamingContext API.
- Example: file systems, socket connections, and Akka actors.
+ Examples: file systems, socket connections, and Akka actors.
- *Advanced sources*: Sources like Kafka, Flume, Kinesis, Twitter, etc. are available through
extra utility classes. These require linking against extra dependencies as discussed in the
[linking](#linking) section.
@@ -610,11 +610,11 @@ We are going to discuss some of the sources present in each category later in th
Note that, if you want to receive multiple streams of data in parallel in your streaming
application, you can create multiple input DStreams (discussed
further in the [Performance Tuning](#level-of-parallelism-in-data-receiving) section). This will
-create multiple receivers which will simultaneously receive multiple data streams. But note that
-Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the
-Spark Streaming application. Hence, it is important to remember that Spark Streaming application
+create multiple receivers which will simultaneously receive multiple data streams. But note that a
+Spark worker/executor is a long-running task, hence it occupies one of the cores allocated to the
+Spark Streaming application. Therefore, it is important to remember that a Spark Streaming application
needs to be allocated enough cores (or threads, if running locally) to process the received data,
-as well as, to run the receiver(s).
+as well as to run the receiver(s).
##### Points to remember
{:.no_toc}
@@ -623,13 +623,13 @@ as well as, to run the receiver(s).
Either of these means that only one thread will be used for running tasks locally. If you are using
an input DStream based on a receiver (e.g. sockets, Kafka, Flume, etc.), then the single thread will
be used to run the receiver, leaving no thread for processing the received data. Hence, when
- running locally, always use "local[*n*]" as the master URL where *n* > number of receivers to run
- (see [Spark Properties](configuration.html#spark-properties.html) for information on how to set
+ running locally, always use "local[*n*]" as the master URL, where *n* > number of receivers to run
+ (see [Spark Properties](configuration.html#spark-properties) for information on how to set
the master).
- Extending the logic to running on a cluster, the number of cores allocated to the Spark Streaming
- application must be more than the number of receivers. Otherwise the system will receive data, but
- not be able to process them.
+ application must be more than the number of receivers. Otherwise the system will receive data, but
+ not be able to process it.
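To make the "local[*n*]" rule concrete, here is a sketch for an application with two receivers (the app name is arbitrary):

{% highlight scala %}
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Two receivers in this application, so use strictly more than two local threads
val conf = new SparkConf().setMaster("local[4]").setAppName("TwoReceiverApp")
val ssc = new StreamingContext(conf, Seconds(1))
{% endhighlight %}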
### Basic Sources
{:.no_toc}
@@ -639,7 +639,7 @@ which creates a DStream from text
data received over a TCP socket connection. Besides sockets, the StreamingContext API provides
methods for creating DStreams from files and Akka actors as input sources.
-- **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as
+- **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as:
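For instance, for plain text files the simplest flavor can be sketched as follows; the directory path is a placeholder, and any HDFS-compatible file system works.

{% highlight scala %}
// Monitors the given directory and processes any new files that appear in it
val lines = streamingContext.textFileStream("hdfs://namenode:8020/user/logs/")
{% endhighlight %}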
@@ -682,14 +682,14 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea
### Advanced Sources
{:.no_toc}
-Python API As of Spark 1.3,
+Python API As of Spark {{site.SPARK_VERSION_SHORT}},
out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in the future.
This category of sources requires interfacing with external non-Spark libraries, some of them with
complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts
-of dependencies, the functionality to create DStreams from these sources have been moved to separate
-libraries, that can be [linked](#linking) to explicitly when necessary. For example, if you want to
-create a DStream using data from Twitter's stream of tweets, you have to do the following.
+of dependencies, the functionality to create DStreams from these sources has been moved to separate
+libraries that can be [linked](#linking) to explicitly when necessary. For example, if you want to
+create a DStream using data from Twitter's stream of tweets, you have to do the following:
1. *Linking*: Add the artifact `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` to the
SBT/Maven project dependencies.
@@ -719,11 +719,11 @@ TwitterUtils.createStream(jssc);
Note that these advanced sources are not available in the Spark shell, hence applications based on
these advanced sources cannot be tested in the shell. If you really want to use them in the Spark
shell you will have to download the corresponding Maven artifact's JAR along with its dependencies
-and it in the classpath.
+and add it to the classpath.
Some of these advanced sources are as follows.
-- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.1.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details.
+- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.2.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details.
- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details.
@@ -743,7 +743,7 @@ Some of these advanced sources are as follows.
Python API This is not yet supported in Python.
-Input DStreams can also be created out of custom data sources. All you have to do is implement an
+Input DStreams can also be created out of custom data sources. All you have to do is implement a
user-defined **receiver** (see next section to understand what that is) that can receive data from
the custom sources and push it into Spark. See the [Custom Receiver
Guide](streaming-custom-receivers.html) for details.
@@ -753,14 +753,12 @@ Guide](streaming-custom-receivers.html) for details.
There can be two kinds of data sources based on their *reliability*. Sources
(like Kafka and Flume) allow the transferred data to be acknowledged. If the system receiving
-data from these *reliable* sources acknowledge the received data correctly, it can be ensured
-that no data gets lost due to any kind of failure. This leads to two kinds of receivers.
+data from these *reliable* sources acknowledges the received data correctly, it can be ensured
+that no data will be lost due to any kind of failure. This leads to two kinds of receivers:
-1. *Reliable Receiver* - A *reliable receiver* correctly acknowledges a reliable
- source that the data has been received and stored in Spark with replication.
-1. *Unreliable Receiver* - These are receivers for sources that do not support acknowledging. Even
- for reliable sources, one may implement an unreliable receiver that do not go into the complexity
- of acknowledging correctly.
+1. *Reliable Receiver* - A *reliable receiver* correctly sends acknowledgment to a reliable
+ source when the data has been received and stored in Spark with replication.
+1. *Unreliable Receiver* - An *unreliable receiver* does *not* send acknowledgment to a source. This can be used for sources that do not support acknowledgment, or even for reliable sources when one does not want or need to go into the complexity of acknowledgment.
The details of how to write a reliable receiver are discussed in the
[Custom Receiver Guide](streaming-custom-receivers.html).
@@ -828,7 +826,7 @@ Some of the common ones are as follows.
cogroup(otherStream, [numTasks])
-When called on DStream of (K, V) and (K, W) pairs, return a new DStream of
+When called on a DStream of (K, V) and (K, W) pairs, return a new DStream of
(K, Seq[V], Seq[W]) tuples.
@@ -852,13 +850,13 @@ A few of these transformations are worth discussing in more detail.
The `updateStateByKey` operation allows you to maintain arbitrary state while continuously updating
it with new information. To use this, you will have to do two steps.
-1. Define the state - The state can be of arbitrary data type.
+1. Define the state - The state can be an arbitrary data type.
1. Define the state update function - Specify with a function how to update the state using the
-previous state and the new values from input stream.
+previous state and the new values from an input stream.
Let's illustrate this with an example. Say you want to maintain a running count of each word
seen in a text data stream. Here, the running count is the state and it is an integer. We
-define the update function as
+define the update function as:
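Assuming the `pairs` DStream of `(word, 1)` pairs from the quick example, the update function can be sketched as:

{% highlight scala %}
def updateFunction(newValues: Seq[Int], runningCount: Option[Int]): Option[Int] = {
  // Add the counts from the new batch to the previous running count
  Some(runningCount.getOrElse(0) + newValues.sum)
}

// Note: stateful transformations like this require a checkpoint directory to be set
val runningCounts = pairs.updateStateByKey[Int](updateFunction _)
{% endhighlight %}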
@@ -947,7 +945,7 @@ operation that is not exposed in the DStream API.
For example, the functionality of joining every batch in a data stream
with another dataset is not directly exposed in the DStream API. However,
you can easily use `transform` to do this. This enables very powerful possibilities. For example,
-if you want to do real-time data cleaning by joining the input data stream with precomputed
+one can do real-time data cleaning by joining the input data stream with precomputed
spam information (maybe generated with Spark as well) and then filtering based on it.
-In fact, you can also use [machine learning](mllib-guide.html) and
-[graph computation](graphx-programming-guide.html) algorithms in the `transform` method.
+Note that the supplied function gets called in every batch interval. This allows you to do
+time-varying RDD operations; that is, the RDD operations, number of partitions, broadcast variables,
+etc. can be changed between batches.
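A sketch of that pattern, where `wordCounts` is assumed to be a `DStream[(String, Int)]` and `spamInfoRDD` a precomputed `RDD[(String, Boolean)]` of spam flags (both are placeholders for illustration):

{% highlight scala %}
// Join each batch with the precomputed spam information and drop the spam entries
val cleanedDStream = wordCounts.transform { rdd =>
  rdd.join(spamInfoRDD)                            // (word, (count, isSpam))
     .filter { case (_, (_, isSpam)) => !isSpam }  // keep only non-spam words
     .mapValues { case (count, _) => count }       // back to (word, count)
}
{% endhighlight %}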
#### Window Operations
{:.no_toc}
Spark Streaming also provides *windowed computations*, which allow you to apply
-transformations over a sliding window of data. This following figure illustrates this sliding
+transformations over a sliding window of data. The following figure illustrates this sliding
window.
@@ -1009,11 +1008,11 @@ window.
As shown in the figure, every time the window *slides* over a source DStream,
the source RDDs that fall within the window are combined and operated upon to produce the
-RDDs of the windowed DStream. In this specific case, the operation is applied over last 3 time
+RDDs of the windowed DStream. In this specific case, the operation is applied over the last 3 time
units of data, and slides by 2 time units. This shows that any window operation needs to
specify two parameters.
- * window length - The duration of the window (3 in the figure)
+ * window length - The duration of the window (3 in the figure).
* sliding interval - The interval at which the window operation is performed (2 in
the figure).
@@ -1021,7 +1020,7 @@ These two parameters must be multiples of the batch interval of the source DStre
figure).
Let's illustrate the window operations with an example. Say, you want to extend the
-[earlier example](#a-quick-example) by generating word counts over last 30 seconds of data,
+[earlier example](#a-quick-example) by generating word counts over the last 30 seconds of data,
every 10 seconds. To do this, we have to apply the `reduceByKey` operation on the `pairs` DStream of
`(word, 1)` pairs over the last 30 seconds of data. This is done using the
operation `reduceByKeyAndWindow`.
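With the `pairs` DStream from the quick example, that call can be sketched as:

{% highlight scala %}
// Reduce over the last 30 seconds of data, every 10 seconds
val windowedWordCounts =
  pairs.reduceByKeyAndWindow((a: Int, b: Int) => a + b, Seconds(30), Seconds(10))
{% endhighlight %}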
@@ -1096,13 +1095,13 @@ said two parameters - windowLength and slideInterval.
-A more efficient version of the above reduceByKeyAndWindow() where the reduce
+A more efficient version of the above reduceByKeyAndWindow() where the reduce
value of each window is calculated incrementally using the reduce values of the previous window.
- This is done by reducing the new data that enter the sliding window, and "inverse reducing" the
- old data that leave the window. An example would be that of "adding" and "subtracting" counts
- of keys as the window slides. However, it is applicable to only "invertible reduce functions",
+ This is done by reducing the new data that enters the sliding window, and "inverse reducing" the
+ old data that leaves the window. An example would be that of "adding" and "subtracting" counts
+ of keys as the window slides. However, it is applicable only to "invertible reduce functions",
that is, those reduce functions which have a corresponding "inverse reduce" function (taken as
- parameter invFunc. Like in reduceByKeyAndWindow, the number of reduce tasks
+ parameter invFunc). Like in reduceByKeyAndWindow, the number of reduce tasks
is configurable through an optional argument. Note that [checkpointing](#checkpointing) must be
enabled for using this operation.
@@ -1224,7 +1223,7 @@ For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.stre
***
## Output Operations on DStreams
-Output operations allow DStream's data to be pushed out external systems like a database or a file systems.
+Output operations allow DStream's data to be pushed out to external systems like a database or a file system.
Since the output operations actually allow the transformed data to be consumed by external systems,
they trigger the actual execution of all the DStream transformations (similar to actions for RDDs).
Currently, the following output operations are defined:
@@ -1233,7 +1232,7 @@ Currently, the following output operations are defined:
Output Operation
Meaning
print()
-Prints first ten elements of every batch of data in a DStream on the driver node running
+Prints the first ten elements of every batch of data in a DStream on the driver node running
the streaming application. This is useful for development and debugging.
Python API This is called
@@ -1242,12 +1241,12 @@ Currently, the following output operations are defined:
saveAsTextFiles(prefix, [suffix])
-Save this DStream's contents as a text files. The file name at each batch interval is
+Save this DStream's contents as text files. The file name at each batch interval is
generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
saveAsObjectFiles(prefix, [suffix])
-Save this DStream's contents as a SequenceFile of serialized Java objects. The file
+Save this DStream's contents as SequenceFiles of serialized Java objects. The file
name at each batch interval is generated based on prefix and
suffix: "prefix-TIME_IN_MS[.suffix]".
@@ -1257,7 +1256,7 @@ Currently, the following output operations are defined:
saveAsHadoopFiles(prefix, [suffix])
-Save this DStream's contents as a Hadoop file. The file name at each batch interval is
+Save this DStream's contents as Hadoop files. The file name at each batch interval is
generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
Python API This is not available in
@@ -1267,7 +1266,7 @@ Currently, the following output operations are defined:
foreachRDD(func)
The most generic output operator that applies a function, func, to each RDD generated from
- the stream. This function should push the data in each RDD to a external system, like saving the RDD to
+ the stream. This function should push the data in each RDD to an external system, such as saving the RDD to
files, or writing it over the network to a database. Note that the function func is executed
in the driver process running the streaming application, and will usually have RDD actions in it
that will force the computation of the streaming RDDs.
@@ -1277,14 +1276,14 @@ Currently, the following output operations are defined:
### Design Patterns for using foreachRDD
{:.no_toc}
-`dstream.foreachRDD` is a powerful primitive that allows data to sent out to external systems.
+`dstream.foreachRDD` is a powerful primitive that allows data to be sent out to external systems.
However, it is important to understand how to use this primitive correctly and efficiently.
Some of the common mistakes to avoid are as follows.
Often writing data to an external system requires creating a connection object
(e.g. TCP connection to a remote server) and using it to send data to a remote system.
For this purpose, a developer may inadvertently try creating a connection object at
-the Spark driver, but try to use it in a Spark worker to save records in the RDDs.
+the Spark driver, and then try to use it in a Spark worker to save records in the RDDs.
For example (in Scala),
@@ -1346,7 +1345,7 @@ dstream.foreachRDD(lambda rdd: rdd.foreach(sendRecord))
Typically, creating a connection object has time and resource overheads. Therefore, creating and
destroying a connection object for each record can incur unnecessarily high overheads and can
significantly reduce the overall throughput of the system. A better solution is to use
-`rdd.foreachPartition` - create a single connection object and send all the records in a RDD
+`rdd.foreachPartition` - create a single connection object and send all the records in an RDD
partition using that connection.
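A sketch of that pattern; `createNewConnection()` and `connection.send(...)` are hypothetical stand-ins for whatever client library your external system provides.

{% highlight scala %}
dstream.foreachRDD { rdd =>
  rdd.foreachPartition { partitionOfRecords =>
    // One connection per partition, reused for every record in that partition
    val connection = createNewConnection()   // hypothetical helper
    partitionOfRecords.foreach(record => connection.send(record))
    connection.close()
  }
}
{% endhighlight %}

A further refinement is to maintain a static pool of connection objects so that connections can be lazily created and reused across batches.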
@@ -1427,26 +1426,6 @@ You can easily use [DataFrames and SQL](sql-programming-guide.html) operations o
{% highlight scala %}
-/** Lazily instantiated singleton instance of SQLContext */
-object SQLContextSingleton {
- @transient private var instance: SQLContext = null
-
- // Instantiate SQLContext on demand
- def getInstance(sparkContext: SparkContext): SQLContext = synchronized {
- if (instance == null) {
- instance = new SQLContext(sparkContext)
- }
- instance
- }
-}
-
-...
-
-/** Case class for converting RDD to DataFrame */
-case class Row(word: String)
-
-...
-
/** DataFrame operations inside your streaming program */
val words: DStream[String] = ...
@@ -1454,11 +1433,11 @@ val words: DStream[String] = ...
words.foreachRDD { rdd =>
// Get the singleton instance of SQLContext
- val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext)
+ val sqlContext = SQLContext.getOrCreate(rdd.sparkContext)
import sqlContext.implicits._
- // Convert RDD[String] to RDD[case class] to DataFrame
- val wordsDataFrame = rdd.map(w => Row(w)).toDF()
+ // Convert RDD[String] to DataFrame
+ val wordsDataFrame = rdd.toDF("word")
// Register as table
wordsDataFrame.registerTempTable("words")
@@ -1476,19 +1455,6 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma
{% highlight java %}
-/** Lazily instantiated singleton instance of SQLContext */
-class JavaSQLContextSingleton {
- static private transient SQLContext instance = null;
- static public SQLContext getInstance(SparkContext sparkContext) {
- if (instance == null) {
- instance = new SQLContext(sparkContext);
- }
- return instance;
- }
-}
-
-...
-
/** Java Bean class for converting RDD to DataFrame */
public class JavaRow implements java.io.Serializable {
private String word;
@@ -1512,7 +1478,9 @@ words.foreachRDD(
new Function2<JavaRDD<String>, Time, Void>() {
@Override
public Void call(JavaRDD rdd, Time time) {
- SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context());
+
+ // Get the singleton instance of SQLContext
+ SQLContext sqlContext = SQLContext.getOrCreate(rdd.context());
// Convert RDD[String] to RDD[case class] to DataFrame
JavaRDD<JavaRow> rowRDD = rdd.map(new Function<String, JavaRow>() {
@@ -1581,7 +1549,7 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma
-You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember sufficient amount of streaming data such that query can run. Otherwise the StreamingContext, which is unaware of the any asynchronous SQL queries, will delete off old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or equivalent in other languages).
+You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember a sufficient amount of streaming data such that the query can run. Otherwise the StreamingContext, which is unaware of any asynchronous SQL queries, will delete old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or the equivalent in other languages).
See the [DataFrames and SQL](sql-programming-guide.html) guide to learn more about DataFrames.
@@ -1594,7 +1562,7 @@ You can also easily use machine learning algorithms provided by [MLlib](mllib-gu
## Caching / Persistence
Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is,
-using `persist()` method on a DStream will automatically persist every RDD of that DStream in
+using the `persist()` method on a DStream will automatically persist every RDD of that DStream in
memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple
operations on the same data). For window-based operations like `reduceByWindow` and
`reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true.
@@ -1606,28 +1574,27 @@ default persistence level is set to replicate the data to two nodes for fault-to
Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in
memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More
-information on different persistence levels can be found in
-[Spark Programming Guide](programming-guide.html#rdd-persistence).
+information on different persistence levels can be found in the [Spark Programming Guide](programming-guide.html#rdd-persistence).
***
## Checkpointing
A streaming application must operate 24/7 and hence must be resilient to failures unrelated
to the application logic (e.g., system failures, JVM crashes, etc.). For this to be possible,
-Spark Streaming needs to *checkpoints* enough information to a fault-
+Spark Streaming needs to *checkpoint* enough information to a fault-
tolerant storage system such that it can recover from failures. There are two types of data
that are checkpointed.
- *Metadata checkpointing* - Saving of the information defining the streaming computation to
fault-tolerant storage like HDFS. This is used to recover from failure of the node running the
driver of the streaming application (discussed in detail later). Metadata includes:
- + *Configuration* - The configuration that were used to create the streaming application.
+ + *Configuration* - The configuration that was used to create the streaming application.
+ *DStream operations* - The set of DStream operations that define the streaming application.
+ *Incomplete batches* - Batches whose jobs are queued but have not completed yet.
- *Data checkpointing* - Saving of the generated RDDs to reliable storage. This is necessary
in some *stateful* transformations that combine data across multiple batches. In such
- transformations, the generated RDDs depends on RDDs of previous batches, which causes the length
- of the dependency chain to keep increasing with time. To avoid such unbounded increase in recovery
+ transformations, the generated RDDs depend on RDDs of previous batches, which causes the length
+ of the dependency chain to keep increasing with time. To avoid such unbounded increases in recovery
time (proportional to dependency chain), intermediate RDDs of stateful transformations are periodically
*checkpointed* to reliable storage (e.g. HDFS) to cut off the dependency chains.
@@ -1641,10 +1608,10 @@ transformations are used.
Checkpointing must be enabled for applications with any of the following requirements:
- *Usage of stateful transformations* - If either `updateStateByKey` or `reduceByKeyAndWindow` (with
- inverse function) is used in the application, then the checkpoint directory must be provided for
- allowing periodic RDD checkpointing.
+ inverse function) is used in the application, then the checkpoint directory must be provided to
+ allow for periodic RDD checkpointing.
- *Recovering from failures of the driver running the application* - Metadata checkpoints are used
- for to recover with progress information.
+ to recover with progress information.
Note that simple streaming applications without the aforementioned stateful transformations can be
run without enabling checkpointing. The recovery from driver failures will also be partial in
@@ -1659,7 +1626,7 @@ Checkpointing can be enabled by setting a directory in a fault-tolerant,
reliable file system (e.g., HDFS, S3, etc.) to which the checkpoint information will be saved.
This is done by using `streamingContext.checkpoint(checkpointDirectory)`. This will allow you to
use the aforementioned stateful transformations. Additionally,
-if you want make the application recover from driver failures, you should rewrite your
+if you want to make the application recover from driver failures, you should rewrite your
streaming application to have the following behavior.
+ When the program is being started for the first time, it will create a new StreamingContext,
@@ -1780,18 +1747,17 @@ You can also explicitly create a `StreamingContext` from the checkpoint data and
In addition to using `getOrCreate` one also needs to ensure that the driver process gets
restarted automatically on failure. This can only be done by the deployment infrastructure that is
used to run the application. This is further discussed in the
-[Deployment](#deploying-applications.html) section.
+[Deployment](#deploying-applications) section.
Note that checkpointing of RDDs incurs the cost of saving to reliable storage.
This may cause an increase in the processing time of those batches where RDDs get checkpointed.
Hence, the interval of
checkpointing needs to be set carefully. At small batch sizes (say 1 second), checkpointing every
batch may significantly reduce operation throughput. Conversely, checkpointing too infrequently
-causes the lineage and task sizes to grow which may have detrimental effects. For stateful
+causes the lineage and task sizes to grow, which may have detrimental effects. For stateful
transformations that require RDD checkpointing, the default interval is a multiple of the
batch interval that is at least 10 seconds. It can be set by using
-`dstream.checkpoint(checkpointInterval)`. Typically, a checkpoint interval of 5 - 10 times of
-sliding interval of a DStream is good setting to try.
+`dstream.checkpoint(checkpointInterval)`. Typically, a checkpoint interval of 5 - 10 sliding intervals of a DStream is a good setting to try.
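For example, with a 10-second sliding interval one might write the following; `stateDStream` is a placeholder for whichever stateful DStream is being checkpointed.

{% highlight scala %}
// Roughly 10 sliding intervals between checkpoints
stateDStream.checkpoint(Seconds(100))
{% endhighlight %}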
***
@@ -1864,17 +1830,17 @@ To run a Spark Streaming applications, you need to have the following.
{:.no_toc}
If a running Spark Streaming application needs to be upgraded with new
-application code, then there are two possible mechanism.
+application code, then there are two possible mechanisms.
- The upgraded Spark Streaming application is started and run in parallel to the existing application.
-Once the new one (receiving the same data as the old one) has been warmed up and ready
+Once the new one (receiving the same data as the old one) has been warmed up and is ready
for prime time, the old one can be brought down. Note that this can be done for data sources that support
sending the data to two destinations (i.e., the earlier and upgraded applications).
- The existing application is shut down gracefully (see
[`StreamingContext.stop(...)`](api/scala/index.html#org.apache.spark.streaming.StreamingContext)
or [`JavaStreamingContext.stop(...)`](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html)
-for graceful shutdown options) which ensure data that have been received is completely
+for graceful shutdown options) which ensures that data that has been received is completely
processed before shutdown. Then the
upgraded application can be started, which will start processing from the same point where the earlier
application left off. Note that this can be done only with input sources that support source-side buffering
@@ -1909,10 +1875,10 @@ The following two metrics in web UI are particularly important:
to finish.
If the batch processing time is consistently more than the batch interval and/or the queueing
-delay keeps increasing, then it indicates the system is
-not able to process the batches as fast they are being generated and falling behind.
+delay keeps increasing, then it indicates that the system is
+not able to process the batches as fast as they are being generated and is falling behind.
In that case, consider
-[reducing](#reducing-the-processing-time-of-each-batch) the batch processing time.
+[reducing](#reducing-the-batch-processing-times) the batch processing time.
The progress of a Spark Streaming program can also be monitored using the
[StreamingListener](api/scala/index.html#org.apache.spark.streaming.scheduler.StreamingListener) interface,
@@ -1923,8 +1889,8 @@ and it is likely to be improved upon (i.e., more information reported) in the fu
***************************************************************************************************
# Performance Tuning
-Getting the best performance of a Spark Streaming application on a cluster requires a bit of
-tuning. This section explains a number of the parameters and configurations that can tuned to
+Getting the best performance out of a Spark Streaming application on a cluster requires a bit of
+tuning. This section explains a number of the parameters and configurations that can be tuned to
improve the performance of your application. At a high level, you need to consider two things:
1. Reducing the processing time of each batch of data by efficiently using cluster resources.
@@ -1934,22 +1900,22 @@ improve the performance of you application. At a high level, you need to conside
## Reducing the Batch Processing Times
There are a number of optimizations that can be done in Spark to minimize the processing time of
-each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section
+each batch. These have been discussed in detail in the [Tuning Guide](tuning.html). This section
highlights some of the most important ones.
### Level of Parallelism in Data Receiving
{:.no_toc}
-Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to deserialized
+Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to be deserialized
and stored in Spark. If the data receiving becomes a bottleneck in the system, then consider
parallelizing the data receiving. Note that each input DStream
creates a single receiver (running on a worker machine) that receives a single stream of data.
Receiving multiple data streams can therefore be achieved by creating multiple input DStreams
and configuring them to receive different partitions of the data stream from the source(s).
For example, a single Kafka input DStream receiving two topics of data can be split into two
-Kafka input streams, each receiving only one topic. This would run two receivers on two workers,
-thus allowing data to be received in parallel, and increasing overall throughput. These multiple
-DStream can be unioned together to create a single DStream. Then the transformations that was
-being applied on the single input DStream can applied on the unified stream. This is done as follows.
+Kafka input streams, each receiving only one topic. This would run two receivers,
+allowing data to be received in parallel, thus increasing overall throughput. These multiple
+DStreams can be unioned together to create a single DStream. Then the transformations that were
+being applied on a single input DStream can be applied on the unified stream. This is done as follows.
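A sketch of that pattern, reusing the placeholders (`ssc`, `zkQuorum`, `group`, `topic`) from the receiver-based Kafka example:

{% highlight scala %}
import org.apache.spark.streaming.kafka.KafkaUtils

val numStreams = 2
val kafkaStreams = (1 to numStreams).map { _ =>
  KafkaUtils.createStream(ssc, zkQuorum, group, Map(topic -> 1))
}
// Union the receivers' streams and process them as a single DStream
val unifiedStream = ssc.union(kafkaStreams)
unifiedStream.print()
{% endhighlight %}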
@@ -1977,10 +1943,10 @@ Another parameter that should be considered is the receiver's blocking interval,
which is determined by the [configuration parameter](configuration.html#spark-streaming)
`spark.streaming.blockInterval`. For most receivers, the received data is coalesced together into
blocks of data before storing inside Spark's memory. The number of blocks in each batch
-determines the number of tasks that will be used to process those
+determines the number of tasks that will be used to process
the received data in a map-like transformation. The number of tasks per receiver per batch will be
approximately (batch interval / block interval). For example, a block interval of 200 ms will
-create 10 tasks per 2 second batches. Too low the number of tasks (that is, less than the number
+create 10 tasks per 2-second batch. If the number of tasks is too low (that is, less than the number
of cores per machine), then it will be inefficient as all available cores will not be used to
process the data. To increase the number of tasks for a given batch interval, reduce the
block interval. However, the recommended minimum value of block interval is about 50 ms,
@@ -1988,7 +1954,7 @@ below which the task launching overheads may be a problem.
An alternative to receiving data with multiple input streams / receivers is to explicitly repartition
the input data stream (using `inputStream.repartition()`).
-This distributes the received batches of data across specified number of machines in the cluster
+This distributes the received batches of data across the specified number of machines in the cluster
before further processing.
### Level of Parallelism in Data Processing
@@ -1996,7 +1962,7 @@ before further processing.
Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the
computation is not high enough. For example, for distributed reduce operations like `reduceByKey`
and `reduceByKeyAndWindow`, the default number of parallel tasks is controlled by
-the`spark.default.parallelism` [configuration property](configuration.html#spark-properties). You
+the `spark.default.parallelism` [configuration property](configuration.html#spark-properties). You
can pass the level of parallelism as an argument (see
[`PairDStreamFunctions`](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions)
documentation), or set the `spark.default.parallelism`
@@ -2004,20 +1970,20 @@ documentation), or set the `spark.default.parallelism`
### Data Serialization
{:.no_toc}
-The overheads of data serialization can be reduce by tuning the serialization formats. In case of streaming, there are two types of data that are being serialized.
+The overheads of data serialization can be reduced by tuning the serialization formats. In the case of streaming, there are two types of data that are being serialized.
-* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is unsufficient to hold all the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format.
+* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is insufficient to hold all of the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format.
-* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operation persist data in memory as they would be processed multiple times. However, unlike Spark, by default RDDs are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) to minimize GC overheads.
+* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operations persist data in memory as they would be processed multiple times. However, unlike the Spark Core default of [StorageLevel.MEMORY_ONLY](api/scala/index.html#org.apache.spark.storage.StorageLevel$), persisted RDDs generated by streaming computations are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) by default to minimize GC overheads.
-In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization)) for more details. Consider registering custom classes, and disabling object reference tracking for Kryo (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)).
+In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization) for more details. For Kryo, consider registering custom classes, and disabling object reference tracking (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)).
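A minimal sketch of these Kryo settings, assuming the registered class (a hypothetical `SensorReading`) stands in for your own application classes:

    import org.apache.spark.SparkConf

    // Hypothetical application class; register the classes your records actually use.
    case class SensorReading(id: String, value: Double)

    val conf = new SparkConf()
      .setAppName("KryoTuningSketch")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.referenceTracking", "false")  // disable object reference tracking
      .registerKryoClasses(Array(classOf[SensorReading]))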
-In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overheads.
+In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of a few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overhead.
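For example, a minimal sketch of overriding the default serialized persistence, assuming `wordCounts` is an existing DStream and `ssc` an existing StreamingContext:

    import org.apache.spark.storage.StorageLevel

    // Persist generated RDDs as deserialized objects instead of serialized bytes.
    wordCounts.persist(StorageLevel.MEMORY_ONLY)

    // Receiver-based input streams can likewise be created with a non-serialized storage level.
    val lines = ssc.socketTextStream("localhost", 9999, StorageLevel.MEMORY_AND_DISK)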
### Task Launching Overheads
{:.no_toc}
If the number of tasks launched per second is high (say, 50 or more per second), then the overhead
-of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second
+of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second
latencies. The overhead can be reduced by the following changes:
* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task
@@ -2036,7 +2002,7 @@ thus allowing sub-second batch size to be viable.
For a Spark Streaming application running on a cluster to be stable, the system should be able to
process data as fast as it is being received. In other words, batches of data should be processed
as fast as they are being generated. Whether this is true for an application can be found by
-[monitoring](#monitoring) the processing times in the streaming web UI, where the batch
+[monitoring](#monitoring-applications) the processing times in the streaming web UI, where the batch
processing time should be less than the batch interval.
Depending on the nature of the streaming
@@ -2049,35 +2015,35 @@ production can be sustained.
A good approach to figure out the right batch size for your application is to test it with a
conservative batch interval (say, 5-10 seconds) and a low data rate. To verify whether the system
-is able to keep up with data rate, you can check the value of the end-to-end delay experienced
+is able to keep up with the data rate, you can check the value of the end-to-end delay experienced
by each processed batch (either look for "Total delay" in Spark driver log4j logs, or use the
[StreamingListener](api/scala/index.html#org.apache.spark.streaming.scheduler.StreamingListener)
interface).
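A possible sketch of reading this delay programmatically through the listener interface, assuming `ssc` is an existing StreamingContext:

    import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted}

    class DelayListener extends StreamingListener {
      override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
        // totalDelay is reported in milliseconds once the batch has fully completed.
        batchCompleted.batchInfo.totalDelay.foreach { delayMs =>
          println(s"Batch ${batchCompleted.batchInfo.batchTime}: total delay $delayMs ms")
        }
      }
    }

    ssc.addStreamingListener(new DelayListener)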
If the delay is maintained to be comparable to the batch size, then the system is stable. Otherwise,
if the delay is continuously increasing, it means that the system is unable to keep up and it is
therefore unstable. Once you have an idea of a stable configuration, you can try increasing the
-data rate and/or reducing the batch size. Note that momentary increase in the delay due to
-temporary data rate increases maybe fine as long as the delay reduces back to a low value
+data rate and/or reducing the batch size. Note that a momentary increase in the delay due to
+temporary data rate increases may be fine as long as the delay reduces back to a low value
(i.e., less than batch size).
***
## Memory Tuning
-Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail
+Tuning the memory usage and GC behavior of Spark applications has been discussed in great detail
in the [Tuning Guide](tuning.html#memory-tuning). It is strongly recommended that you read that. In this section, we discuss a few tuning parameters specifically in the context of Spark Streaming applications.
-The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used. For example, if you want to use a window operation on last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes of worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then necessary memory will be low.
+The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used. For example, if you want to use a window operation on the last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then the necessary memory will be low.
-In general, since the data received through receivers are stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. Its best to try and see the memory usage on a small scale and estimate accordingly.
+In general, since the data received through receivers is stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. It's best to try and see the memory usage on a small scale and estimate accordingly.
-Another aspect of memory tuning is garbage collection. For a streaming application that require low latency, it is undesirable to have large pauses caused by JVM Garbage Collection.
+Another aspect of memory tuning is garbage collection. For a streaming application that requires low latency, it is undesirable to have large pauses caused by JVM Garbage Collection.
-There are a few parameters that can help you tune the memory usage and GC overheads.
+There are a few parameters that can help you tune the memory usage and GC overheads:
-* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both, the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time.
+* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time.
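A short sketch of the compression setting mentioned above, which trades extra CPU for smaller persisted blocks:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.rdd.compress", "true")  // compress serialized, persisted RDD partitions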
-* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used. For example, if you are using window operation of 10 minutes, then Spark Streaming will keep around last 10 minutes of data, and actively throw away older data.
-Data can be retained for longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`.
+* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used. For example, if you are using a window operation of 10 minutes, then Spark Streaming will keep around the last 10 minutes of data, and actively throw away older data.
+Data can be retained for a longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`.
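For example, a minimal sketch of keeping an hour of generated RDDs around for interactive queries, assuming `ssc` is an existing StreamingContext:

    import org.apache.spark.streaming.Minutes

    // Keep RDDs of the last 60 minutes instead of clearing them as soon as they are no longer needed.
    ssc.remember(Minutes(60))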
* **CMS Garbage Collector**: Use of the concurrent mark-and-sweep GC is strongly recommended for keeping GC-related pauses consistently low. Even though concurrent GC is known to reduce the
overall processing throughput of the system, its use is still recommended to achieve more
@@ -2107,18 +2073,18 @@ re-computed from the original fault-tolerant dataset using the lineage of operat
1. Assuming that all of the RDD transformations are deterministic, the data in the final transformed
RDD will always be the same irrespective of failures in the Spark cluster.
-Spark operates on data on fault-tolerant file systems like HDFS or S3. Hence,
+Spark operates on data in fault-tolerant file systems like HDFS or S3. Hence,
all of the RDDs generated from the fault-tolerant data are also fault-tolerant. However, this is not
the case for Spark Streaming as the data in most cases is received over the network (except when
`fileStream` is used). To achieve the same fault-tolerance properties for all of the generated RDDs,
the received data is replicated among multiple Spark executors in worker nodes in the cluster
(default replication factor is 2). This leads to two kinds of data in the
-system that needs to recovered in the event of failures:
+system that need to be recovered in the event of failures:
1. *Data received and replicated* - This data survives failure of a single worker node as a copy
- of it exists on one of the nodes.
+ of it exists on one of the other nodes.
1. *Data received but buffered for replication* - Since this is not replicated,
- the only way to recover that data is to get it again from the source.
+ the only way to recover this data is to get it again from the source.
Furthermore, there are two kinds of failures that we should be concerned about:
@@ -2145,13 +2111,13 @@ In any stream processing system, broadly speaking, there are three steps in proc
1. *Receiving the data*: The data is received from sources using Receivers or otherwise.
-1. *Transforming the data*: The data received data is transformed using DStream and RDD transformations.
+1. *Transforming the data*: The received data is transformed using DStream and RDD transformations.
1. *Pushing out the data*: The final transformed data is pushed out to external systems like file systems, databases, dashboards, etc.
-If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming.
+If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide an exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming.
-1. *Receiving the data*: Different input sources provided different guarantees. This is discussed in detail in the next subsection.
+1. *Receiving the data*: Different input sources provide different guarantees. This is discussed in detail in the next subsection.
1. *Transforming the data*: All data that has been received will be processed _exactly once_, thanks to the guarantees that RDDs provide. Even if there are failures, as long as the received input data is accessible, the final transformed RDDs will always have the same contents.
@@ -2163,9 +2129,9 @@ Different input sources provide different guarantees, ranging from _at-least onc
### With Files
{:.no_toc}
-If all of the input data is already present in a fault-tolerant files system like
-HDFS, Spark Streaming can always recover from any failure and process all the data. This gives
-*exactly-once* semantics, that all the data will be processed exactly once no matter what fails.
+If all of the input data is already present in a fault-tolerant file system like
+HDFS, Spark Streaming can always recover from any failure and process all of the data. This gives
+*exactly-once* semantics, meaning all of the data will be processed exactly once no matter what fails.
### With Receiver-based Sources
{:.no_toc}
@@ -2174,21 +2140,21 @@ scenario and the type of receiver.
As we discussed [earlier](#receiver-reliability), there are two types of receivers:
1. *Reliable Receiver* - These receivers acknowledge reliable sources only after ensuring that
- the received data has been replicated. If such a receiver fails,
- the buffered (unreplicated) data does not get acknowledged to the source. If the receiver is
- restarted, the source will resend the data, and therefore no data will be lost due to the failure.
-1. *Unreliable Receiver* - Such receivers can lose data when they fail due to worker
- or driver failures.
+ the received data has been replicated. If such a receiver fails, the source will not receive
+ acknowledgment for the buffered (unreplicated) data. Therefore, if the receiver is
+ restarted, the source will resend the data, and no data will be lost due to the failure.
+1. *Unreliable Receiver* - Such receivers do *not* send acknowledgment and therefore *can* lose
+ data when they fail due to worker or driver failures.
Depending on what type of receivers are used we achieve the following semantics.
If a worker node fails, then there is no data loss with reliable receivers. With unreliable
receivers, data received but not replicated can get lost. If the driver node fails,
-then besides these losses, all the past data that was received and replicated in memory will be
+then besides these losses, all of the past data that was received and replicated in memory will be
lost. This will affect the results of the stateful transformations.
To avoid this loss of past received data, Spark 1.2 introduced _write
-ahead logs_ which saves the received data to fault-tolerant storage. With the [write ahead logs
-enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides at-least once guarantee.
+ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs
+enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least-once guarantee.
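A sketch of enabling this, with a hypothetical checkpoint directory (write ahead logs require checkpointing to a fault-tolerant file system):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val conf = new SparkConf()
      .setAppName("WriteAheadLogSketch")
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")

    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint("hdfs://namenode:8020/spark/checkpoints")  // hypothetical HDFS path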
The following table summarizes the semantics under failures:
@@ -2234,7 +2200,7 @@ The following table summarizes the semantics under failures:
### With Kafka Direct API
{:.no_toc}
-In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark 1.3) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html).
+In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all of the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement an exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark {{site.SPARK_VERSION_SHORT}}) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html).
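A rough sketch of the direct approach, assuming the spark-streaming-kafka artifact is on the classpath, `ssc` is an existing StreamingContext, and the broker addresses and topic name are placeholders:

    import kafka.serializer.StringDecoder
    import org.apache.spark.streaming.kafka.KafkaUtils

    val kafkaParams = Map("metadata.broker.list" -> "broker1:9092,broker2:9092")
    val topics = Set("events")

    // Each RDD in this stream corresponds to a well-defined range of Kafka offsets,
    // which is what makes exactly-once receipt possible.
    val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, topics)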
## Semantics of output operations
{:.no_toc}
@@ -2248,9 +2214,16 @@ additional effort may be necessary to achieve exactly-once semantics. There are
- *Transactional updates*: All updates are made transactionally so that updates are made exactly once atomically. One way to do this would be the following.
- - Use the batch time (available in `foreachRDD`) and the partition index of the transformed RDD to create an identifier. This identifier uniquely identifies a blob data in the streaming application.
- - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else if this was already committed, skip the update.
+  - Use the batch time (available in `foreachRDD`) and the partition index of the RDD to create an identifier. This identifier uniquely identifies a blob of data in the streaming application.
+ - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else, if this was already committed, skip the update.
+ dstream.foreachRDD { (rdd, time) =>
+ rdd.foreachPartition { partitionIterator =>
+ val partitionId = TaskContext.get.partitionId()
+ val uniqueId = generateUniqueId(time.milliseconds, partitionId)
+ // use this uniqueId to transactionally commit the data in partitionIterator
+ }
+ }
***************************************************************************************************
***************************************************************************************************
@@ -2325,7 +2298,7 @@ package and renamed for better clarity.
- Java docs
* [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html),
[JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and
- [PairJavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/PairJavaDStream.html)
+ [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html)
* [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html),
[FlumeUtils](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html),
[KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html)
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index ee0904c9e5d5..56087499464e 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -51,7 +51,7 @@
raw_input = input
xrange = range
-SPARK_EC2_VERSION = "1.3.1"
+SPARK_EC2_VERSION = "1.4.0"
SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__))
VALID_SPARK_VERSIONS = set([
@@ -70,6 +70,7 @@
"1.2.1",
"1.3.0",
"1.3.1",
+ "1.4.0",
])
SPARK_TACHYON_MAP = {
@@ -82,6 +83,7 @@
"1.2.1": "0.5.0",
"1.3.0": "0.5.0",
"1.3.1": "0.5.0",
+ "1.4.0": "0.6.4",
}
DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION
@@ -89,7 +91,7 @@
# Default location to get the spark-ec2 scripts (and ami-list) from
DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2"
-DEFAULT_SPARK_EC2_BRANCH = "branch-1.3"
+DEFAULT_SPARK_EC2_BRANCH = "branch-1.4"
def setup_external_libs(libs):
@@ -219,7 +221,8 @@ def parse_args():
"(default: %default).")
parser.add_option(
"--hadoop-major-version", default="1",
- help="Major version of Hadoop (default: %default)")
+ help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " +
+ "(Hadoop 2.4.0) (default: %default)")
parser.add_option(
"-D", metavar="[ADDRESS:]PORT", dest="proxy_port",
help="Use SSH dynamic port forwarding to create a SOCKS proxy at " +
@@ -271,7 +274,8 @@ def parse_args():
help="Launch fresh slaves, but use an existing stopped master if possible")
parser.add_option(
"--worker-instances", type="int", default=1,
- help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)")
+ help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " +
+ "is used as Hadoop major version (default: %default)")
parser.add_option(
"--master-opts", type="string", default="",
help="Extra options to give to master through SPARK_MASTER_OPTS variable " +
@@ -761,6 +765,10 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
if opts.ganglia:
modules.append('ganglia')
+ # Clear SPARK_WORKER_INSTANCES if running on YARN
+ if opts.hadoop_major_version == "yarn":
+ opts.worker_instances = ""
+
# NOTE: We should clone the repository before running deploy_files to
# prevent ec2-variables.sh from being overwritten
print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format(
@@ -998,6 +1006,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes]
slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes]
+ worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else ""
template_vars = {
"master_list": '\n'.join(master_addresses),
"active_master": active_master,
@@ -1011,7 +1020,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
"spark_version": spark_v,
"tachyon_version": tachyon_v,
"hadoop_major_version": opts.hadoop_major_version,
- "spark_worker_instances": "%d" % opts.worker_instances,
+ "spark_worker_instances": worker_instances_str,
"spark_master_opts": opts.master_opts
}
diff --git a/examples/pom.xml b/examples/pom.xml
index e4efee7b5e64..e6884b09dca9 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,7 +21,7 @@
  <parent>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
  </parent>
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 29158d5c8565..dac649d1d5ae 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -97,7 +97,7 @@ public static void main(String[] args) {
DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
- // LogisticRegression.transform will only use the 'features' column.
+ // LogisticRegressionModel.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
DataFrame results = model2.transform(test);
diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py
index 1456c8731284..0ea7cfb7025a 100755
--- a/examples/src/main/python/kmeans.py
+++ b/examples/src/main/python/kmeans.py
@@ -68,7 +68,7 @@ def closestPoint(p, centers):
closest = data.map(
lambda p: (closestPoint(p, kPoints), (p, 1)))
pointStats = closest.reduceByKey(
- lambda (p1, c1), (p2, c2): (p1 + p2, c1 + c2))
+ lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1]))
newPoints = pointStats.map(
lambda st: (st[0], st[1][0] / st[1][1])).collect()
diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py
new file mode 100644
index 000000000000..f0ca97c72494
--- /dev/null
+++ b/examples/src/main/python/ml/cross_validator.py
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.ml import Pipeline
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.evaluation import BinaryClassificationEvaluator
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
+from pyspark.sql import Row, SQLContext
+
+"""
+A simple example demonstrating model selection using CrossValidator.
+This example also demonstrates how Pipelines are Estimators.
+Run with:
+
+ bin/spark-submit examples/src/main/python/ml/cross_validator.py
+"""
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="CrossValidatorExample")
+ sqlContext = SQLContext(sc)
+
+ # Prepare training documents, which are labeled.
+ LabeledDocument = Row("id", "text", "label")
+ training = sc.parallelize([(0, "a b c d e spark", 1.0),
+ (1, "b d", 0.0),
+ (2, "spark f g h", 1.0),
+ (3, "hadoop mapreduce", 0.0),
+ (4, "b spark who", 1.0),
+ (5, "g d a y", 0.0),
+ (6, "spark fly", 1.0),
+ (7, "was mapreduce", 0.0),
+ (8, "e spark program", 1.0),
+ (9, "a e c l", 0.0),
+ (10, "spark compile", 1.0),
+ (11, "hadoop software", 0.0)
+ ]) \
+ .map(lambda x: LabeledDocument(*x)).toDF()
+
+    # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ tokenizer = Tokenizer(inputCol="text", outputCol="words")
+ hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
+ lr = LogisticRegression(maxIter=10)
+ pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
+
+ # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
+ # This will allow us to jointly choose parameters for all Pipeline stages.
+ # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ # We use a ParamGridBuilder to construct a grid of parameters to search over.
+ # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
+ # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
+ paramGrid = ParamGridBuilder() \
+ .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
+ .addGrid(lr.regParam, [0.1, 0.01]) \
+ .build()
+
+ crossval = CrossValidator(estimator=pipeline,
+ estimatorParamMaps=paramGrid,
+ evaluator=BinaryClassificationEvaluator(),
+ numFolds=2) # use 3+ folds in practice
+
+ # Run cross-validation, and choose the best set of parameters.
+ cvModel = crossval.fit(training)
+
+ # Prepare test documents, which are unlabeled.
+ Document = Row("id", "text")
+ test = sc.parallelize([(4L, "spark i j k"),
+ (5L, "l m n"),
+ (6L, "mapreduce spark"),
+ (7L, "apache hadoop")]) \
+ .map(lambda x: Document(*x)).toDF()
+
+ # Make predictions on test documents. cvModel uses the best model found (lrModel).
+ prediction = cvModel.transform(test)
+ selected = prediction.select("id", "text", "probability", "prediction")
+ for row in selected.collect():
+ print(row)
+
+ sc.stop()
diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py
new file mode 100644
index 000000000000..6446f0fe5eea
--- /dev/null
+++ b/examples/src/main/python/ml/gradient_boosted_trees.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import GBTClassifier
+from pyspark.ml.feature import StringIndexer
+from pyspark.ml.regression import GBTRegressor
+from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics
+from pyspark.mllib.util import MLUtils
+from pyspark.sql import Row, SQLContext
+
+"""
+A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline.
+Note: GBTClassifier only supports binary classification currently
+Run with:
+ bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py
+"""
+
+
+def testClassification(train, test):
+ # Train a GradientBoostedTrees model.
+
+ rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel")
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = BinaryClassificationMetrics(predictionAndLabels)
+ print("AUC %.3f" % metrics.areaUnderROC)
+
+
+def testRegression(train, test):
+ # Train a GradientBoostedTrees model.
+
+ rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel")
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = RegressionMetrics(predictionAndLabels)
+ print("rmse %.3f" % metrics.rootMeanSquaredError)
+ print("r2 %.3f" % metrics.r2)
+ print("mae %.3f" % metrics.meanAbsoluteError)
+
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ print("Usage: gradient_boosted_trees", file=sys.stderr)
+ exit(1)
+ sc = SparkContext(appName="PythonGBTExample")
+ sqlContext = SQLContext(sc)
+
+ # Load and parse the data file into a dataframe.
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+ # Map labels into an indexed column of labels in [0, numLabels)
+ stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
+ si_model = stringIndexer.fit(df)
+ td = si_model.transform(df)
+ [train, test] = td.randomSplit([0.7, 0.3])
+ testClassification(train, test)
+ testRegression(train, test)
+ sc.stop()
diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py
new file mode 100644
index 000000000000..c7730e1bfacd
--- /dev/null
+++ b/examples/src/main/python/ml/random_forest_example.py
@@ -0,0 +1,87 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import RandomForestClassifier
+from pyspark.ml.feature import StringIndexer
+from pyspark.ml.regression import RandomForestRegressor
+from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics
+from pyspark.mllib.util import MLUtils
+from pyspark.sql import Row, SQLContext
+
+"""
+A simple example demonstrating a RandomForest Classification/Regression Pipeline.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/random_forest_example.py
+"""
+
+
+def testClassification(train, test):
+ # Train a RandomForest model.
+ # Setting featureSubsetStrategy="auto" lets the algorithm choose.
+ # Note: Use larger numTrees in practice.
+
+ rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4)
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = MulticlassMetrics(predictionAndLabels)
+ print("weighted f-measure %.3f" % metrics.weightedFMeasure())
+ print("precision %s" % metrics.precision())
+ print("recall %s" % metrics.recall())
+
+
+def testRegression(train, test):
+ # Train a RandomForest model.
+ # Note: Use larger numTrees in practice.
+
+ rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4)
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = RegressionMetrics(predictionAndLabels)
+ print("rmse %.3f" % metrics.rootMeanSquaredError)
+ print("r2 %.3f" % metrics.r2)
+ print("mae %.3f" % metrics.meanAbsoluteError)
+
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ print("Usage: random_forest_example", file=sys.stderr)
+ exit(1)
+ sc = SparkContext(appName="PythonRandomForestExample")
+ sqlContext = SQLContext(sc)
+
+ # Load and parse the data file into a dataframe.
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+ # Map labels into an indexed column of labels in [0, numLabels)
+ stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
+ si_model = stringIndexer.fit(df)
+ td = si_model.transform(df)
+ [train, test] = td.randomSplit([0.7, 0.3])
+ testClassification(train, test)
+ testRegression(train, test)
+ sc.stop()
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
new file mode 100644
index 000000000000..a9f29dab2d60
--- /dev/null
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import pprint
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import LogisticRegression
+from pyspark.mllib.linalg import DenseVector
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.sql import SQLContext
+
+"""
+A simple example demonstrating ways to specify parameters for Estimators and Transformers.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/simple_params_example.py
+"""
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ print("Usage: simple_params_example", file=sys.stderr)
+ exit(1)
+ sc = SparkContext(appName="PythonSimpleParamsExample")
+ sqlContext = SQLContext(sc)
+
+ # prepare training data.
+ # We create an RDD of LabeledPoints and convert them into a DataFrame.
+ # A LabeledPoint is an Object with two fields named label and features
+ # and Spark SQL identifies these fields and creates the schema appropriately.
+ training = sc.parallelize([
+ LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])),
+ LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])),
+ LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])),
+ LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF()
+
+ # Create a LogisticRegression instance with maxIter = 10.
+ # This instance is an Estimator.
+ lr = LogisticRegression(maxIter=10)
+ # Print out the parameters, documentation, and any default values.
+ print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
+
+ # We may also set parameters using setter methods.
+ lr.setRegParam(0.01)
+
+ # Learn a LogisticRegression model. This uses the parameters stored in lr.
+ model1 = lr.fit(training)
+
+ # Since model1 is a Model (i.e., a Transformer produced by an Estimator),
+ # we can view the parameters it used during fit().
+ # This prints the parameter (name: value) pairs, where names are unique IDs for this
+ # LogisticRegression instance.
+ print("Model 1 was fit using parameters:\n")
+ pprint.pprint(model1.extractParamMap())
+
+ # We may alternatively specify parameters using a parameter map.
+ # paramMap overrides all lr parameters set earlier.
+ paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"}
+
+ # Now learn a new model using the new parameters.
+ model2 = lr.fit(training, paramMap)
+ print("Model 2 was fit using parameters:\n")
+ pprint.pprint(model2.extractParamMap())
+
+ # prepare test data.
+ test = sc.parallelize([
+ LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])),
+ LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])),
+ LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF()
+
+ # Make predictions on test data using the Transformer.transform() method.
+ # LogisticRegressionModel.transform will only use the 'features' column.
+ # Note that model2.transform() outputs a 'myProbability' column instead of the usual
+ # 'probability' column since we renamed the lr.probabilityCol parameter previously.
+ result = model2.transform(test) \
+ .select("features", "label", "myProbability", "prediction") \
+ .collect()
+
+ for row in result:
+ print("features=%s,label=%s -> prob=%s, prediction=%s"
+ % (row.features, row.label, row.myProbability, row.prediction))
+
+ sc.stop()
diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py
index 96ddac761d69..e1fd85b082c0 100644
--- a/examples/src/main/python/parquet_inputformat.py
+++ b/examples/src/main/python/parquet_inputformat.py
@@ -51,7 +51,7 @@
parquet_rdd = sc.newAPIHadoopFile(
path,
- 'parquet.avro.AvroParquetInputFormat',
+ 'org.apache.parquet.avro.AvroParquetInputFormat',
'java.lang.Void',
'org.apache.avro.generic.IndexedRecord',
valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter')
diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
index 32e02eab8b03..75c82117cbad 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
@@ -22,7 +22,7 @@ import org.apache.spark.SparkContext._
/**
* Executes a roll up-style query against Apache logs.
- *
+ *
* Usage: LogQuery [logFile]
*/
object LogQuery {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala
new file mode 100644
index 000000000000..b54466fd48bc
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
+import org.apache.spark.sql.DataFrame
+
+/**
+ * An example runner for linear regression with elastic-net (mixing L1/L2) regularization.
+ * Run with
+ * {{{
+ * bin/run-example ml.LinearRegressionExample [options]
+ * }}}
+ * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt` which can be
+ * trained by
+ * {{{
+ * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \
+ * data/mllib/sample_linear_regression_data.txt
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object LinearRegressionExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ regParam: Double = 0.0,
+ elasticNetParam: Double = 0.0,
+ maxIter: Int = 100,
+ tol: Double = 1E-6,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("LinearRegressionExample") {
+ head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.")
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ .action((x, c) => c.copy(regParam = x))
+ opt[Double]("elasticNetParam")
+ .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " +
+ s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " +
+ s"L1 and L2, default: ${defaultParams.elasticNetParam}")
+ .action((x, c) => c.copy(elasticNetParam = x))
+ opt[Int]("maxIter")
+ .text(s"maximum number of iterations, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("tol")
+ .text(s"the convergence tolerance of iterations, Smaller value will lead " +
+ s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}")
+ .action((x, c) => c.copy(tol = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"LinearRegressionExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, "regression", params.fracTest)
+
+ val lir = new LinearRegression()
+ .setFeaturesCol("features")
+ .setLabelCol("label")
+ .setRegParam(params.regParam)
+ .setElasticNetParam(params.elasticNetParam)
+ .setMaxIter(params.maxIter)
+ .setTol(params.tol)
+
+ // Train the model
+ val startTime = System.nanoTime()
+ val lirModel = lir.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Print the weights and intercept for linear regression.
+ println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}")
+
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label")
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label")
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala
new file mode 100644
index 000000000000..3cf193f353fb
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.feature.StringIndexer
+import org.apache.spark.sql.DataFrame
+
+/**
+ * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization.
+ * Run with
+ * {{{
+ * bin/run-example ml.LogisticRegressionExample [options]
+ * }}}
+ * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt` which can be
+ * trained by
+ * {{{
+ * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \
+ * data/mllib/sample_libsvm_data.txt
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object LogisticRegressionExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ regParam: Double = 0.0,
+ elasticNetParam: Double = 0.0,
+ maxIter: Int = 100,
+ fitIntercept: Boolean = true,
+ tol: Double = 1E-6,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("LogisticRegressionExample") {
+ head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.")
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ .action((x, c) => c.copy(regParam = x))
+ opt[Double]("elasticNetParam")
+ .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " +
+ s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " +
+ s"L1 and L2, default: ${defaultParams.elasticNetParam}")
+ .action((x, c) => c.copy(elasticNetParam = x))
+ opt[Int]("maxIter")
+ .text(s"maximum number of iterations, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Boolean]("fitIntercept")
+ .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}")
+ .action((x, c) => c.copy(fitIntercept = x))
+ opt[Double]("tol")
+ .text(s"the convergence tolerance of iterations, Smaller value will lead " +
+ s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}")
+ .action((x, c) => c.copy(tol = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"LogisticRegressionExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, "classification", params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol("indexedLabel")
+ stages += labelIndexer
+
+ val lor = new LogisticRegression()
+ .setFeaturesCol("features")
+ .setLabelCol("indexedLabel")
+ .setRegParam(params.regParam)
+ .setElasticNetParam(params.elasticNetParam)
+ .setMaxIter(params.maxIter)
+ .setTol(params.tol)
+
+ stages += lor
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel]
+ // Print the weights and intercept for logistic regression.
+ println(s"Weights: ${lorModel.weights} Intercept: ${lorModel.intercept}")
+
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel")
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel")
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index e8a991f50e33..a0561e2573fc 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -87,7 +87,7 @@ object SimpleParamsExample {
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
// Make predictions on test data using the Transformer.transform() method.
- // LogisticRegression.transform will only use the 'features' column.
+ // LogisticRegressionModel.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
model2.transform(test.toDF())
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index b0613632c994..3381941673db 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -22,7 +22,6 @@ import scala.language.reflectiveCalls
import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -354,7 +353,11 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
+ *
+   * This is just for demo purposes. In general, don't copy this code because it is NOT efficient
+ * due to the use of structural types, which leads to one reflection call per record.
*/
+ // scalastyle:off structural.type
private[mllib] def meanSquaredError(
model: { def predict(features: Vector): Double },
data: RDD[LabeledPoint]): Double = {
@@ -363,4 +366,5 @@ object DecisionTreeRunner {
err * err
}.mean()
}
+ // scalastyle:on structural.type
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
index 9a1aab036aa0..f8c71ccabc43 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
@@ -41,22 +41,22 @@ object DenseGaussianMixture {
private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) {
val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
val ctx = new SparkContext(conf)
-
+
val data = ctx.textFile(inputFile).map { line =>
Vectors.dense(line.trim.split(' ').map(_.toDouble))
}.cache()
-
+
val clusters = new GaussianMixture()
.setK(k)
.setConvergenceTol(convergenceTol)
.setMaxIterations(maxIterations)
.run(data)
-
+
for (i <- 0 until clusters.k) {
- println("weight=%f\nmu=%s\nsigma=\n%s\n" format
+ println("weight=%f\nmu=%s\nsigma=\n%s\n" format
(clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
}
-
+
println("Cluster labels (first <= 100):")
val clusterLabels = clusters.predict(data)
clusterLabels.take(100).foreach { x =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
index b336751d8161..813c8554f519 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
@@ -40,7 +40,7 @@ object MQTTPublisher {
StreamingExamples.setStreamingLogLevels()
val Seq(brokerUrl, topic) = args.toSeq
-
+
var client: MqttClient = null
try {
@@ -59,10 +59,10 @@ object MQTTPublisher {
println(s"Published data. topic: ${msgtopic.getName()}; Message: $message")
} catch {
case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
- Thread.sleep(10)
+ Thread.sleep(10)
println("Queue is full, wait for to consume data from the message queue")
- }
- }
+ }
+ }
} catch {
case e: MqttException => println("Exception Caught: " + e)
} finally {
@@ -107,7 +107,7 @@ object MQTTWordCount {
val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2)
val words = lines.flatMap(x => x.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
-
+
wordCounts.print()
ssc.start()
ssc.awaitTermination()
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index 1f3e619d97a2..7a7dccc3d092 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -21,7 +21,7 @@
  <parent>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
  </parent>
@@ -42,15 +42,46 @@
      <groupId>org.apache.flume</groupId>
      <artifactId>flume-ng-sdk</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.thrift</groupId>
+          <artifactId>libthrift</artifactId>
+        </exclusion>
+      </exclusions>
    </dependency>
    <dependency>
      <groupId>org.apache.flume</groupId>
      <artifactId>flume-ng-core</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.thrift</groupId>
+          <artifactId>libthrift</artifactId>
+        </exclusion>
+      </exclusions>
    </dependency>
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
    </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+      <scope>test</scope>
+    </dependency>
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
index fd01807fc3ac..dc2a4ab138e1 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
@@ -21,7 +21,6 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable
-import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.flume.Channel
import org.apache.commons.lang3.RandomStringUtils
@@ -45,8 +44,7 @@ import org.apache.commons.lang3.RandomStringUtils
private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel,
val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging {
val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads,
- new ThreadFactoryBuilder().setDaemon(true)
- .setNameFormat("Spark Sink Processor Thread - %d").build()))
+ new SparkSinkThreadFactory("Spark Sink Processor Thread - %d")))
// Protected by `sequenceNumberToProcessor`
private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]()
// This sink will not persist sequence numbers and reuses them if it gets restarted.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala
similarity index 51%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
rename to external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala
index a4a3a66b8b22..845fc8debda7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala
@@ -14,23 +14,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.streaming.flume.sink
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
+import java.util.concurrent.ThreadFactory
+import java.util.concurrent.atomic.AtomicLong
/**
- * Overrides our expression evaluation tests and reruns them after optimization has occured. This
- * is to ensure that constant folding and other optimizations do not break anything.
+ * Thread factory that generates daemon threads with a specified name format.
*/
-class ExpressionOptimizationSuite extends ExpressionEvaluationSuite {
- override def checkEvaluation(
- expression: Expression,
- expected: Any,
- inputRow: Row = EmptyRow): Unit = {
- val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
- val optimizedPlan = DefaultOptimizer.execute(plan)
- super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow)
+private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory {
+
+ private val threadId = new AtomicLong()
+
+ override def newThread(r: Runnable): Thread = {
+ val t = new Thread(r, nameFormat.format(threadId.incrementAndGet()))
+ t.setDaemon(true)
+ t
}
+
}
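For context, the new factory drops into `Executors` exactly where `ThreadFactoryBuilder` was used before. A brief usage sketch, assuming the caller lives inside `org.apache.spark.streaming.flume.sink` since the class is `private[sink]`:

    import java.util.concurrent.{Executors, TimeUnit}

    // Hypothetical caller in the same package (the factory is private[sink]).
    val pool = Executors.newFixedThreadPool(
      2, new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))
    pool.submit(new Runnable {
      // Threads come out daemonized and named "Spark Sink Processor Thread - 1", "- 2", ...
      override def run(): Unit = println(Thread.currentThread().getName)
    })
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)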
diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
index 650b2fbe1c14..fa43629d4977 100644
--- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
+++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
@@ -24,16 +24,24 @@ import scala.collection.JavaConversions._
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
-import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
import org.apache.flume.Context
import org.apache.flume.channel.MemoryChannel
import org.apache.flume.event.EventBuilder
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
+
+// Due to MNG-1378, there is not a way to include test dependencies transitively.
+// We cannot include Spark core tests as a dependency here because it depends on
+// Spark core main, which has too many dependencies to require here manually.
+// For this reason, we continue to use FunSuite and ignore the scalastyle checks
+// that fail if this is detected.
+//scalastyle:off
import org.scalatest.FunSuite
class SparkSinkSuite extends FunSuite {
+//scalastyle:on
+
val eventsPerBatch = 1000
val channelCapacity = 5000
@@ -185,9 +193,8 @@ class SparkSinkSuite extends FunSuite {
count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = {
(1 to count).map(_ => {
- lazy val channelFactoryExecutor =
- Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true).
- setNameFormat("Flume Receiver Channel Thread - %d").build())
+ lazy val channelFactoryExecutor = Executors.newCachedThreadPool(
+ new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d"))
lazy val channelFactory =
new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor)
val transceiver = new NettyTransceiver(address, channelFactory)
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 8df7edbdcad3..14f7daaf417e 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-streaming-flume-sink_${scala.binary.version}</artifactId>
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index 60e2994431b3..1e32a365a1ee 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -152,9 +152,9 @@ class FlumeReceiver(
val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(),
Executors.newCachedThreadPool())
val channelPipelineFactory = new CompressionChannelPipelineFactory()
-
+
new NettyServer(
- responder,
+ responder,
new InetSocketAddress(host, port),
channelFactory,
channelPipelineFactory,
@@ -188,12 +188,12 @@ class FlumeReceiver(
override def preferredLocation: Option[String] = Option(host)
- /** A Netty Pipeline factory that will decompress incoming data from
+ /** A Netty Pipeline factory that will decompress incoming data from
* the Netty client and compress data going back to the client.
*
* The compression on the return is required because Flume requires
- * a successful response to indicate it can remove the event/batch
- * from the configured channel
+ * a successful response to indicate it can remove the event/batch
+ * from the configured channel
*/
private[streaming]
class CompressionChannelPipelineFactory extends ChannelPipelineFactory {
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
index 92fa5b41be89..583e7dca317a 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
@@ -110,7 +110,7 @@ private[streaming] class FlumePollingReceiver(
}
/**
- * A wrapper around the transceiver and the Avro IPC API.
+ * A wrapper around the transceiver and the Avro IPC API.
* @param transceiver The transceiver to use for communication with Flume
* @param client The client that the callbacks are received on.
*/
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
index 93afe50c2134..d772b9ca9b57 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
@@ -31,16 +31,16 @@ import org.apache.flume.conf.Configurables
import org.apache.flume.event.EventBuilder
import org.scalatest.concurrent.Eventually._
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
-import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext}
import org.apache.spark.streaming.flume.sink._
import org.apache.spark.util.{ManualClock, Utils}
-class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging {
+class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging {
val batchCount = 5
val eventsPerBatch = 100
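This suite, like most others in the patch, now extends `SparkFunSuite` instead of ScalaTest's `FunSuite`. The real base class ships in Spark core's test-jar (hence the new test-jar dependencies in the poms); the sketch below is only an assumption-laden outline of the pattern, a base suite that brackets each test with a uniform log line:

    import org.scalatest.{FunSuite, Outcome}

    // Sketch only: a shared base suite that logs when each test starts and finishes,
    // so output from every module has the same greppable shape.
    abstract class SparkFunSuite extends FunSuite {
      protected override def withFixture(test: NoArgTest): Outcome = {
        val suiteName = this.getClass.getName
        try {
          println(s"\n===== TEST OUTPUT FOR $suiteName: '${test.name}' =====\n")
          test()
        } finally {
          println(s"\n===== FINISHED $suiteName: '${test.name}' =====\n")
        }
      }
    }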
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index 39e6754c81db..c926359987d8 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -35,15 +35,15 @@ import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
-import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}
import org.apache.spark.util.Utils
-class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging {
+class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")
var ssc: StreamingContext = null
@@ -138,7 +138,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L
val status = client.appendBatch(inputEvents.toList)
status should be (avro.Status.OK)
}
-
+
eventually(timeout(10 seconds), interval(100 milliseconds)) {
val outputEvents = outputBuffer.flatten.map { _.event }
outputEvents.foreach {
diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml
index 0b79f47647f6..8059c443827e 100644
--- a/external/kafka-assembly/pom.xml
+++ b/external/kafka-assembly/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 243ce6eaca65..ded863bd985e 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.kafka</groupId>
       <artifactId>kafka_${scala.binary.version}</artifactId>
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
index 6cf254a7b69c..65d51d87f848 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
@@ -113,7 +113,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
r.flatMap { tm: TopicMetadata =>
tm.partitionsMetadata.map { pm: PartitionMetadata =>
TopicAndPartition(tm.topic, pm.partitionId)
- }
+ }
}
}
}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
index 6dc4e9517d5a..b608b7595272 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -195,6 +195,8 @@ private class KafkaTestUtils extends Logging {
val props = new Properties()
props.put("metadata.broker.list", brokerAddress)
props.put("serializer.class", classOf[StringEncoder].getName)
+ // wait for all in-sync replicas to ack sends
+ props.put("request.required.acks", "-1")
props
}
@@ -229,21 +231,6 @@ private class KafkaTestUtils extends Logging {
tryAgain(1)
}
- /** Wait until the leader offset for the given topic/partition equals the specified offset */
- def waitUntilLeaderOffset(
- topic: String,
- partition: Int,
- offset: Long): Unit = {
- eventually(Time(10000), Time(100)) {
- val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress))
- val tp = TopicAndPartition(topic, partition)
- val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset
- assert(
- llo == offset,
- s"$topic $partition $offset not reached after timeout")
- }
- }
-
private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = {
def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match {
case Some(partitionState) =>
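Setting `request.required.acks` to `-1` makes the test producer block until every in-sync replica has acknowledged a send, which is why the `waitUntilLeaderOffset` polling helper can be deleted above. A hedged sketch of an equivalent producer setup using the old Kafka 0.8 Scala API (`AckAllProducerSketch`, the topic name, and the message are illustrative):

    import java.util.Properties

    import kafka.producer.{KeyedMessage, Producer, ProducerConfig}
    import kafka.serializer.StringEncoder

    object AckAllProducerSketch {
      def send(brokerAddress: String): Unit = {
        val props = new Properties()
        props.put("metadata.broker.list", brokerAddress)
        props.put("serializer.class", classOf[StringEncoder].getName)
        // -1: the leader acks only after all in-sync replicas have the message, so a
        // successful send() implies the data is already readable at the latest offset.
        props.put("request.required.acks", "-1")

        val producer = new Producer[String, String](new ProducerConfig(props))
        producer.send(new KeyedMessage[String, String]("test-topic", "hello"))
        producer.close()
      }
    }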
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 8be2707528d9..0b8a391a2c56 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -315,7 +315,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
@@ -363,7 +363,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
@@ -427,7 +427,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
@@ -489,7 +489,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
index 4c1d6a03eb2b..c0669fb33665 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
@@ -18,9 +18,7 @@
package org.apache.spark.streaming.kafka;
import java.io.Serializable;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Arrays;
+import java.util.*;
import scala.Tuple2;
@@ -116,7 +114,7 @@ public String call(MessageAndMetadata<String, String> msgAndMd) throws Exception
);
JavaDStream<String> unifiedStream = stream1.union(stream2);
- final HashSet<String> result = new HashSet<String>();
+ final Set<String> result = Collections.synchronizedSet(new HashSet<String>());
unifiedStream.foreachRDD(
new Function<JavaRDD<String>, Void>() {
@Override
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
index 5cf379635354..a9dc6e50613c 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
@@ -72,9 +72,6 @@ public void testKafkaRDD() throws InterruptedException {
HashMap<String, String> kafkaParams = new HashMap<String, String>();
kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress());
- kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length);
- kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length);
-
OffsetRange[] offsetRanges = {
OffsetRange.create(topic1, 0, 0, 1),
OffsetRange.create(topic2, 0, 0, 1)
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
index 540f4ceabab4..e4c659215b76 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -18,9 +18,7 @@
package org.apache.spark.streaming.kafka;
import java.io.Serializable;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Random;
+import java.util.*;
import scala.Tuple2;
@@ -94,7 +92,7 @@ public void testKafkaStream() throws InterruptedException {
topics,
StorageLevel.MEMORY_ONLY_SER());
- final HashMap<String, Long> result = new HashMap<String, Long>();
+ final Map<String, Long> result = Collections.synchronizedMap(new HashMap<String, Long>());
JavaDStream<String> words = stream.map(
new Function<Tuple2<String, String>, String>() {
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index b6d314dfc778..212eb35c61b6 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -28,10 +28,10 @@ import scala.language.postfixOps
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
import kafka.serializer.StringDecoder
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.Eventually
-import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
import org.apache.spark.streaming.dstream.DStream
@@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.Utils
class DirectKafkaStreamSuite
- extends FunSuite
+ extends SparkFunSuite
with BeforeAndAfter
with BeforeAndAfterAll
with Eventually
@@ -99,7 +99,8 @@ class DirectKafkaStreamSuite
ssc, kafkaParams, topics)
}
- val allReceived = new ArrayBuffer[(String, String)]
+ val allReceived =
+ new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)]
stream.foreachRDD { rdd =>
// Get the offset ranges in the RDD
@@ -162,7 +163,7 @@ class DirectKafkaStreamSuite
"Start offset not from latest"
)
- val collectedData = new mutable.ArrayBuffer[String]()
+ val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String]
stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() }
ssc.start()
val newData = Map("b" -> 10)
@@ -208,7 +209,7 @@ class DirectKafkaStreamSuite
"Start offset not from latest"
)
- val collectedData = new mutable.ArrayBuffer[String]()
+ val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String]
stream.foreachRDD { rdd => collectedData ++= rdd.collect() }
ssc.start()
val newData = Map("b" -> 10)
@@ -324,7 +325,8 @@ class DirectKafkaStreamSuite
ssc, kafkaParams, Set(topic))
}
- val allReceived = new ArrayBuffer[(String, String)]
+ val allReceived =
+ new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)]
stream.foreachRDD { rdd => allReceived ++= rdd.collect() }
ssc.start()
@@ -350,8 +352,8 @@ class DirectKafkaStreamSuite
}
object DirectKafkaStreamSuite {
- val collectedData = new mutable.ArrayBuffer[String]()
- var total = -1L
+ val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String]
+ @volatile var total = -1L
class InputInfoCollector extends StreamingListener {
val numRecordsSubmitted = new AtomicLong(0L)
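The changes above swap plain collections for `SynchronizedBuffer` mix-ins and mark `total` as `@volatile` because `foreachRDD` bodies run on job threads while `eventually` polls from the test thread, so unsynchronized state may be observed stale or mid-update. A minimal sketch of the pattern, assuming Scala 2.10/2.11 where `SynchronizedBuffer` still exists (it was later deprecated):

    import scala.collection.mutable

    object SynchronizedBufferSketch {
      // Appended to from streaming job threads, read from the test thread.
      val collected =
        new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String]

      // Writer side, e.g. inside foreachRDD: appends take the buffer's lock.
      def onBatch(records: Seq[String]): Unit = collected ++= records

      // Reader side, e.g. inside eventually: snapshot under the same lock.
      def snapshot(): List[String] = collected.synchronized(collected.toList)
    }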
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
index 7fb841b79cb6..d66830cbacde 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.streaming.kafka
import scala.util.Random
import kafka.common.TopicAndPartition
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll {
+import org.apache.spark.SparkFunSuite
+
+class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll {
private val topic = "kcsuitetopic" + Random.nextInt(10000)
private val topicAndPartition = TopicAndPartition(topic, 0)
private var kc: KafkaCluster = null
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
index 3c875cb76651..d5baf5fd8999 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
@@ -22,11 +22,11 @@ import scala.util.Random
import kafka.serializer.StringDecoder
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
import org.apache.spark._
-class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
+class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
private var kafkaTestUtils: KafkaTestUtils = _
@@ -61,8 +61,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
"group.id" -> s"test-consumer-${Random.nextInt}")
- kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size)
-
val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))
val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
@@ -86,7 +84,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
// this is the "lots of messages" case
kafkaTestUtils.sendMessages(topic, sent)
val sentCount = sent.values.sum
- kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount)
// rdd defined from leaders after sending messages, should get the number sent
val rdd = getRdd(kc, Set(topic))
@@ -113,7 +110,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
val sentOnlyOne = Map("d" -> 1)
kafkaTestUtils.sendMessages(topic, sentOnlyOne)
- kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1)
assert(rdd2.isDefined)
assert(rdd2.get.count === 0, "got messages when there shouldn't be any")
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
index 24699dfc33ad..797b07f80d8e 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
@@ -23,14 +23,14 @@ import scala.language.postfixOps
import scala.util.Random
import kafka.serializer.StringDecoder
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll {
+class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
private var ssc: StreamingContext = _
private var kafkaTestUtils: KafkaTestUtils = _
@@ -65,7 +65,7 @@ class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll {
val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY)
- val result = new mutable.HashMap[String, Long]()
+ val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long]
stream.map(_._2).countByValue().foreachRDD { r =>
val ret = r.collect()
ret.toMap.foreach { kv =>
@@ -77,10 +77,7 @@ class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll {
ssc.start()
eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
- assert(sent.size === result.size)
- sent.keys.foreach { k =>
- assert(sent(k) === result(k).toInt)
- }
+ assert(sent === result)
}
}
}
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
index 38548dd73b82..80e2df62de3f 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
@@ -26,15 +26,15 @@ import scala.util.Random
import kafka.serializer.StringDecoder
import kafka.utils.{ZKGroupTopicDirs, ZkUtils}
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.Eventually
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
import org.apache.spark.util.Utils
-class ReliableKafkaStreamSuite extends FunSuite
+class ReliableKafkaStreamSuite extends SparkFunSuite
with BeforeAndAfterAll with BeforeAndAfter with Eventually {
private val sparkConf = new SparkConf()
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 98f95a9a64fa..0e41e5781784 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.eclipse.paho</groupId>
       <artifactId>org.eclipse.paho.client.mqttv3</artifactId>
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index a19a72c58a70..c4bf5aa7869b 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -29,7 +29,7 @@ import org.apache.commons.lang3.RandomUtils
import org.eclipse.paho.client.mqttv3._
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
@@ -37,10 +37,10 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.scheduler.StreamingListener
import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils
-class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
+class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
private val batchDuration = Milliseconds(500)
private val master = "local[2]"
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index 8b6a8959ac4c..178ae8de13b5 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.twitter4j</groupId>
       <artifactId>twitter4j-stream</artifactId>
diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
index 9ee57d7581d8..d9acb568879f 100644
--- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
+++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
@@ -18,16 +18,16 @@
package org.apache.spark.streaming.twitter
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
import twitter4j.Status
import twitter4j.auth.{NullAuthorization, Authorization}
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging {
+class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging {
val batchDuration = Seconds(1)
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index a50d378b3433..37bfd10d4366 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>${akka.group}</groupId>
       <artifactId>akka-zeromq_${scala.binary.version}</artifactId>
diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
index a7566e733d89..35d2e62c6848 100644
--- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
+++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark.streaming.zeromq
import akka.actor.SupervisorStrategy
import akka.util.ByteString
import akka.zeromq.Subscribe
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-class ZeroMQStreamSuite extends FunSuite {
+class ZeroMQStreamSuite extends SparkFunSuite {
val batchDuration = Seconds(1)
diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml
index 4351a8a12fe2..f138251748c9 100644
--- a/extras/java8-tests/pom.xml
+++ b/extras/java8-tests/pom.xml
@@ -20,7 +20,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml
index 25847a1b33d9..c6f60bc90743 100644
--- a/extras/kinesis-asl/pom.xml
+++ b/extras/kinesis-asl/pom.xml
@@ -20,7 +20,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
@@ -40,6 +40,13 @@
       <artifactId>spark-streaming_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-streaming_${scala.binary.version}</artifactId>
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
index 97c347604928..be8b62d3cc6b 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
@@ -119,7 +119,7 @@ object KinesisWordCountASL extends Logging {
val batchInterval = Milliseconds(2000)
// Kinesis checkpoint interval is the interval at which the DynamoDB is updated with information
- // on sequence number of records that have been received. Same as batchInterval for this
+ // on sequence number of records that have been received. Same as batchInterval for this
// example.
val kinesisCheckpointInterval = batchInterval
@@ -145,7 +145,7 @@ object KinesisWordCountASL extends Logging {
// Map each word to a (word, 1) tuple so we can reduce by key to count the words
val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _)
-
+
// Print the first 10 wordCounts
wordCounts.print()
@@ -210,14 +210,14 @@ object KinesisWordProducerASL {
val randomWords = List("spark", "you", "are", "my", "father")
val totals = scala.collection.mutable.Map[String, Int]()
-
+
// Create the low-level Kinesis Client from the AWS Java SDK.
val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain())
kinesisClient.setEndpoint(endpoint)
println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" +
s" $recordsPerSecond records per second and $wordsPerRecord words per record")
-
+
// Iterate and put records onto the stream per the given recordPerSec and wordsPerRecord
for (i <- 1 to 10) {
// Generate recordsPerSec records to put onto the stream
@@ -255,8 +255,8 @@ object KinesisWordProducerASL {
}
}
-/**
- * Utility functions for Spark Streaming examples.
+/**
+ * Utility functions for Spark Streaming examples.
* This has been lifted from the examples/ project to remove the circular dependency.
*/
private[streaming] object StreamingExamples extends Logging {
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
index 1c9b0c218ae1..83a453755951 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
@@ -23,20 +23,20 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock}
/**
* This is a helper class for managing checkpoint clocks.
*
- * @param checkpointInterval
+ * @param checkpointInterval
* @param currentClock. Default to current SystemClock if none is passed in (mocking purposes)
*/
private[kinesis] class KinesisCheckpointState(
- checkpointInterval: Duration,
+ checkpointInterval: Duration,
currentClock: Clock = new SystemClock())
extends Logging {
-
+
/* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */
val checkpointClock = new ManualClock()
checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds)
/**
- * Check if it's time to checkpoint based on the current time and the derived time
+ * Check if it's time to checkpoint based on the current time and the derived time
* for the next checkpoint
*
* @return true if it's time to checkpoint
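The helper above arms a clock one checkpoint interval into the future and re-arms it after each checkpoint. A simplified sketch of that idea using plain system time (`CheckpointTimer` is an illustrative name, not the real class, which works against Spark's `Clock` abstraction so it can be driven by `ManualClock` in tests):

    // Illustrative stand-in: tracks when the next Kinesis checkpoint is due.
    class CheckpointTimer(intervalMillis: Long) {
      private var nextDeadline: Long = System.currentTimeMillis() + intervalMillis

      /** True once the armed deadline has passed. */
      def shouldCheckpoint(): Boolean = System.currentTimeMillis() >= nextDeadline

      /** Re-arm the deadline one interval into the future after a checkpoint. */
      def advance(): Unit = { nextDeadline = System.currentTimeMillis() + intervalMillis }
    }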
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 7dd8bfdc2a6d..1a8a4cecc114 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -44,12 +44,12 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
* https://github.com/awslabs/amazon-kinesis-client
* This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here:
* http://spark.apache.org/docs/latest/streaming-custom-receivers.html
- * Instances of this class will get shipped to the Spark Streaming Workers to run within a
+ * Instances of this class will get shipped to the Spark Streaming Workers to run within a
* Spark Executor.
*
* @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams
* by the Kinesis Client Library. If you change the App name or Stream name,
- * the KCL will throw errors. This usually requires deleting the backing
+ * the KCL will throw errors. This usually requires deleting the backing
* DynamoDB table with the same name this Kinesis application.
* @param streamName Kinesis stream name
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
@@ -87,7 +87,7 @@ private[kinesis] class KinesisReceiver(
*/
/**
- * workerId is used by the KCL should be based on the ip address of the actual Spark Worker
+ * workerId is used by the KCL and should be based on the IP address of the actual Spark Worker
* where this code runs (not the driver's IP address.)
*/
private var workerId: String = null
@@ -121,7 +121,7 @@ private[kinesis] class KinesisReceiver(
/*
* RecordProcessorFactory creates impls of IRecordProcessor.
- * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the
+ * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the
* IRecordProcessor.processRecords() method.
* We're using our custom KinesisRecordProcessor in this case.
*/
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index f65e743c4e2a..fe9e3a0c793e 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -35,9 +35,9 @@ import com.amazonaws.services.kinesis.model.Record
/**
* Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor.
* This implementation operates on the Array[Byte] from the KinesisReceiver.
- * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each
- * shard in the Kinesis stream upon startup. This is normally done in separate threads,
- * but the KCLs within the KinesisReceivers will balance themselves out if you create
+ * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each
+ * shard in the Kinesis stream upon startup. This is normally done in separate threads,
+ * but the KCLs within the KinesisReceivers will balance themselves out if you create
* multiple Receivers.
*
* @param receiver Kinesis receiver
@@ -69,14 +69,14 @@ private[kinesis] class KinesisRecordProcessor(
* and Spark Streaming's Receiver.store().
*
* @param batch list of records from the Kinesis stream shard
- * @param checkpointer used to update Kinesis when this batch has been processed/stored
+ * @param checkpointer used to update Kinesis when this batch has been processed/stored
* in the DStream
*/
override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) {
if (!receiver.isStopped()) {
try {
/*
- * Notes:
+ * Notes:
* 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming
* Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the
* internally-configured Spark serializer (kryo, etc).
@@ -84,19 +84,19 @@ private[kinesis] class KinesisRecordProcessor(
* ourselves from Spark's internal serialization strategy.
* 3) For performance, the BlockGenerator is asynchronously queuing elements within its
* memory before creating blocks. This prevents the small block scenario, but requires
- * that you register callbacks to know when a block has been generated and stored
+ * that you register callbacks to know when a block has been generated and stored
* (WAL is sufficient for storage) before can checkpoint back to the source.
*/
batch.foreach(record => receiver.store(record.getData().array()))
-
+
logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId")
/*
- * Checkpoint the sequence number of the last record successfully processed/stored
+ * Checkpoint the sequence number of the last record successfully processed/stored
* in the batch.
* In this implementation, we're checkpointing after the given checkpointIntervalMillis.
- * Note that this logic requires that processRecords() be called AND that it's time to
- * checkpoint. I point this out because there is no background thread running the
+ * Note that this logic requires that processRecords() be called AND that it's time to
+ * checkpoint. I point this out because there is no background thread running the
* checkpointer. Checkpointing is tested and trigger only when a new batch comes in.
* If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below).
* However, if the worker dies unexpectedly, a checkpoint may not happen.
@@ -130,16 +130,16 @@ private[kinesis] class KinesisRecordProcessor(
}
} else {
/* RecordProcessor has been stopped. */
- logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" +
+ logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" +
s" and shardId $shardId. No more records will be processed.")
}
}
/**
* Kinesis Client Library is shutting down this Worker for 1 of 2 reasons:
- * 1) the stream is resharding by splitting or merging adjacent shards
+ * 1) the stream is resharding by splitting or merging adjacent shards
* (ShutdownReason.TERMINATE)
- * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason
+ * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason
* (ShutdownReason.ZOMBIE)
*
* @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE
@@ -153,7 +153,7 @@ private[kinesis] class KinesisRecordProcessor(
* Checkpoint to indicate that all records from the shard have been drained and processed.
* It's now OK to read from the new shards that resulted from a resharding event.
*/
- case ShutdownReason.TERMINATE =>
+ case ShutdownReason.TERMINATE =>
KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)
/*
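`KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)` above retries a by-name expression a bounded number of times with randomized backoff. A minimal sketch of such a helper, under the assumption that any exception is retryable (the real helper is more selective about which KCL exceptions it retries):

    import scala.annotation.tailrec
    import scala.util.{Failure, Random, Success, Try}

    object RetrySketch {
      // Evaluate `expression`, allowing up to numRetries attempts in total and sleeping
      // a random amount (bounded by maxBackOffMillis) between attempts.
      @tailrec
      def retryRandom[T](expression: => T, numRetries: Int, maxBackOffMillis: Int): T = {
        Try(expression) match {
          case Success(result) => result
          case Failure(_) if numRetries > 1 =>
            Thread.sleep(Random.nextInt(maxBackOffMillis).toLong)
            retryRandom(expression, numRetries - 1, maxBackOffMillis)
          case Failure(e) => throw e
        }
      }
    }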
diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml
index e14bbae4a9b6..478d0019a25f 100644
--- a/extras/spark-ganglia-lgpl/pom.xml
+++ b/extras/spark-ganglia-lgpl/pom.xml
@@ -20,7 +20,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/graphx/pom.xml b/graphx/pom.xml
index d38a3aa8256b..853dea9a7795 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
@@ -40,6 +40,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
index cc70b396a8dd..4611a3ace219 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -41,14 +41,16 @@ abstract class EdgeRDD[ED](
@transient sc: SparkContext,
@transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) {
+ // scalastyle:off structural.type
private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD }
+ // scalastyle:on structural.type
override protected def getPartitions: Array[Partition] = partitionsRDD.partitions
override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = {
val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context)
if (p.hasNext) {
- p.next._2.iterator.map(_.copy())
+ p.next()._2.iterator.map(_.copy())
} else {
Iterator.empty
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
index eb1dbe52c2fd..f1ecc9e2219d 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
-class EdgeRDDSuite extends FunSuite with LocalSparkContext {
+class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext {
test("cache, getStorageLevel") {
// test to see if getStorageLevel returns correct value after caching
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala
index 5a2c73b41427..094a63472eaa 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala
@@ -17,21 +17,21 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class EdgeSuite extends FunSuite {
+class EdgeSuite extends SparkFunSuite {
test ("compare") {
// descending order
val testEdges: Array[Edge[Int]] = Array(
- Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1),
- Edge(0x2345L, 0x1234L, 1),
- Edge(0x1234L, 0x5678L, 1),
- Edge(0x1234L, 0x2345L, 1),
+ Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1),
+ Edge(0x2345L, 0x1234L, 1),
+ Edge(0x1234L, 0x5678L, 1),
+ Edge(0x1234L, 0x2345L, 1),
Edge(-0x7FEDCBA987654321L, 0x7FEDCBA987654321L, 1)
)
// to ascending order
val sortedEdges = testEdges.sorted(Edge.lexicographicOrdering[Int])
-
+
for (i <- 0 until testEdges.length) {
assert(sortedEdges(i) == testEdges(testEdges.length - i - 1))
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index 68fe83739e39..57a8b95dd12e 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.graphx
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.Graph._
import org.apache.spark.graphx.impl.EdgePartition
import org.apache.spark.rdd._
-import org.scalatest.FunSuite
-class GraphOpsSuite extends FunSuite with LocalSparkContext {
+class GraphOpsSuite extends SparkFunSuite with LocalSparkContext {
test("joinVertices") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 2b1d8e47326f..1f5e27d5508b 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.Graph._
import org.apache.spark.graphx.PartitionStrategy._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class GraphSuite extends FunSuite with LocalSparkContext {
+class GraphSuite extends SparkFunSuite with LocalSparkContext {
def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = {
Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexId, x: VertexId)), 3), "v")
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
index 490b94429ea1..8afa2d403b53 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
@@ -17,12 +17,10 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.rdd._
-class PregelSuite extends FunSuite with LocalSparkContext {
+class PregelSuite extends SparkFunSuite with LocalSparkContext {
test("1 iteration") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
index d0a7198d691d..f1aa685a79c9 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
@@ -17,13 +17,11 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
-import org.apache.spark.{HashPartitioner, SparkContext}
+import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-class VertexRDDSuite extends FunSuite with LocalSparkContext {
+class VertexRDDSuite extends SparkFunSuite with LocalSparkContext {
private def vertices(sc: SparkContext, n: Int) = {
VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5))
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index 515f3a9cd02e..7435647c6d9e 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl
import scala.reflect.ClassTag
import scala.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.graphx._
-class EdgePartitionSuite extends FunSuite {
+class EdgePartitionSuite extends SparkFunSuite {
def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = {
val builder = new EdgePartitionBuilder[A, Int]
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
index fe8304c1cdc3..1203f8959f50 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -17,15 +17,13 @@
package org.apache.spark.graphx.impl
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.graphx._
-class VertexPartitionSuite extends FunSuite {
+class VertexPartitionSuite extends SparkFunSuite {
test("isDefined, filter") {
val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
index accccfc232cd..c965a6eb8df1 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
-class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
+class ConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext {
test("Grid Connected Components") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala
index 61fd0c460556..808877f0590f 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
-class LabelPropagationSuite extends FunSuite with LocalSparkContext {
+class LabelPropagationSuite extends SparkFunSuite with LocalSparkContext {
test("Label Propagation") {
withSpark { sc =>
// Construct a graph with two cliques connected by a single edge
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
index 39c6ace912b0..45f1e3011035 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
@@ -57,7 +56,7 @@ object GridPageRank {
}
-class PageRankSuite extends FunSuite with LocalSparkContext {
+class PageRankSuite extends SparkFunSuite with LocalSparkContext {
def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = {
a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
index 7bd6b7f3c4ab..2991438f5e57 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
-class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
+class SVDPlusPlusSuite extends SparkFunSuite with LocalSparkContext {
test("Test SVD++ with mean square error on training set") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
index f2c38e79c452..d7eaa70ce640 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.lib._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
-class ShortestPathsSuite extends FunSuite with LocalSparkContext {
+class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext {
test("Shortest Path Computations") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
index 1f658c371ffc..d6b03208180d 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
-class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
+class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext {
test("Island Strongly Connected Components") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
index 79bf4e6cd18e..c47552cf3a3b 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut
-class TriangleCountSuite extends FunSuite with LocalSparkContext {
+class TriangleCountSuite extends SparkFunSuite with LocalSparkContext {
test("Count a single triangle") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
index f3b3738db0da..186d0cc2a977 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
@@ -17,10 +17,10 @@
package org.apache.spark.graphx.util
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class BytecodeUtilsSuite extends FunSuite {
+class BytecodeUtilsSuite extends SparkFunSuite {
import BytecodeUtilsSuite.TestClass
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
index 8d9c8ddccbb3..32e0c841c699 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.graphx.util
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx.LocalSparkContext
-class GraphGeneratorsSuite extends FunSuite with LocalSparkContext {
+class GraphGeneratorsSuite extends SparkFunSuite with LocalSparkContext {
test("GraphGenerators.generateRandomEdges") {
val src = 5
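
The GraphX test changes above all apply the same mechanical migration: drop the direct org.scalatest.FunSuite import and extend org.apache.spark.SparkFunSuite instead. A minimal sketch of the resulting suite shape follows; the suite name and the tiny graph it builds are illustrative only, and it assumes the existing graphx test helpers (LocalSparkContext / withSpark) that the real suites already mix in.

    package org.apache.spark.graphx

    import org.apache.spark.SparkFunSuite

    // Illustrative suite showing the migrated pattern: extend SparkFunSuite
    // (Spark's wrapper around ScalaTest's FunSuite) rather than FunSuite directly.
    class ExampleMigratedSuite extends SparkFunSuite with LocalSparkContext {
      test("tiny graph sanity check") {
        withSpark { sc =>
          val edges = sc.parallelize(Seq((1L, 2L), (2L, 3L)))
          val graph = Graph.fromEdgeTuples(edges, defaultValue = 0)
          assert(graph.vertices.count() === 3L)
        }
      }
    }
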
diff --git a/launcher/pom.xml b/launcher/pom.xml
index ebfa7685eaa1..48dd0d5f9106 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -22,14 +22,14 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
 
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-launcher_2.10</artifactId>
   <packaging>jar</packaging>
-  <name>Spark Launcher Project</name>
+  <name>Spark Project Launcher</name>
   <url>http://spark.apache.org/</url>
   <properties>
     <sbt.project.name>launcher</sbt.project.name>
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index 33fd813f7a86..33d65d13f0d2 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -296,6 +296,9 @@ Properties loadPropertiesFile() throws IOException {
try {
fd = new FileInputStream(propsFile);
props.load(new InputStreamReader(fd, "UTF-8"));
+      for (Map.Entry<Object, Object> e : props.entrySet()) {
+        e.setValue(e.getValue().toString().trim());
+      }
diff --git a/mllib/pom.xml b/mllib/pom.xml
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-streaming_${scala.binary.version}</artifactId>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 11a4722722ea..a9bd28df71ee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -17,6 +17,9 @@
package org.apache.spark.ml
+import java.{util => ju}
+
+import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
@@ -175,6 +178,11 @@ class PipelineModel private[ml] (
val stages: Array[Transformer])
extends Model[PipelineModel] with Logging {
+ /** A Java/Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
+ this(uid, stages.asScala.toArray)
+ }
+
override def validateParams(): Unit = {
super.validateParams()
stages.foreach(_.validateParams())
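
The new PipelineModel constructor above is private[ml], so it is only reachable from Spark itself (for example the Java/Python wrappers), but the conversion it performs is plain JavaConverters usage. A small non-Spark sketch of the same pattern, with illustrative names:

    import java.{util => ju}

    import scala.collection.JavaConverters._

    // Primary constructor takes an Array; the auxiliary constructor accepts a
    // java.util.List for Java callers and converts it, as PipelineModel now does.
    class StageHolder(val stages: Array[String]) {
      def this(stages: ju.List[String]) = this(stages.asScala.toArray)
    }

    object StageHolderExample {
      def main(args: Array[String]): Unit = {
        val javaStages = ju.Arrays.asList("tokenizer", "hashingTF", "lr")
        println(new StageHolder(javaStages).stages.mkString(", "))
      }
    }
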
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index d8592eb2d947..62f4b51f770e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -208,7 +208,7 @@ private[ml] object GBTClassificationModel {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index d13109d9da4c..f136bcee9cf2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -74,7 +74,7 @@ class LogisticRegression(override val uid: String)
setDefault(elasticNetParam -> 0.0)
/**
- * Set the maximal number of iterations.
+ * Set the maximum number of iterations.
* Default is 100.
* @group setParam
*/
@@ -90,7 +90,11 @@ class LogisticRegression(override val uid: String)
def setTol(value: Double): this.type = set(tol, value)
setDefault(tol -> 1E-6)
- /** @group setParam */
+ /**
+ * Whether to fit an intercept term.
+ * Default is true.
+ * @group setParam
+ */
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
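
For reference, a short sketch of how the documented setters are used; this only touches the Params object, so no SparkContext is needed (the object name is illustrative):

    import org.apache.spark.ml.classification.LogisticRegression

    object LogisticRegressionSetterExample {
      def main(args: Array[String]): Unit = {
        // maxIter defaults to 100 and fitIntercept to true, per the docs above.
        val lr = new LogisticRegression()
          .setMaxIter(50)
          .setFitIntercept(false)
        println(lr.getMaxIter)       // 50
        println(lr.getFitIntercept)  // false
      }
    }
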
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index b8c7f3c5bc3b..825f9ed1b54b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -37,11 +37,13 @@ import org.apache.spark.storage.StorageLevel
*/
private[ml] trait OneVsRestParams extends PredictorParams {
+ // scalastyle:off structural.type
type ClassifierType = Classifier[F, E, M] forSome {
type F
type M <: ClassificationModel[F, M]
type E <: Classifier[F, E, M]
}
+ // scalastyle:on structural.type
/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -129,6 +131,7 @@ final class OneVsRestModel private[ml] (
// output label and label metadata as prediction
val labelUdf = callUDF(label, DoubleType, col(accColName))
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+ .drop(accColName)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 67600ebd7b38..852a67e06632 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -170,7 +170,7 @@ private[ml] object RandomForestClassificationModel {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 3ae183339015..1e758cb775de 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -41,7 +41,7 @@ class ElementwiseProduct(override val uid: String)
* the vector to multiply with input vectors
* @group param
*/
- val scalingVec: Param[Vector] = new Param(this, "scalingVector", "vector for hadamard product")
+ val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product")
/** @group setParam */
def setScalingVec(value: Vector): this.type = set(scalingVec, value)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index fdd2494fc87a..b0fd06d84fdb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -35,13 +35,13 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
/**
* Centers the data with mean before scaling.
- * It will build a dense output, so this does not work on sparse input
+ * It will build a dense output, so this does not work on sparse input
* and will raise an exception.
* Default: false
* @group param
*/
val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
-
+
/**
* Scales the data to unit standard deviation.
* Default: true
@@ -68,13 +68,13 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
-
+
/** @group setParam */
def setWithMean(value: Boolean): this.type = set(withMean, value)
-
+
/** @group setParam */
def setWithStd(value: Boolean): this.type = set(withStd, value)
-
+
override def fit(dataset: DataFrame): StandardScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index a2dc8a8b960c..f4e250757560 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -88,6 +88,9 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
/**
* :: Experimental ::
* Model fitted by [[StringIndexer]].
+ * NOTE: During transformation, if the input column does not exist,
+ * [[StringIndexerModel.transform]] would return the input dataset unmodified.
+ * This is a temporary fix for the case when target labels do not exist during prediction.
*/
@Experimental
class StringIndexerModel private[ml] (
@@ -112,6 +115,12 @@ class StringIndexerModel private[ml] (
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: DataFrame): DataFrame = {
+ if (!dataset.schema.fieldNames.contains($(inputCol))) {
+ logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
+ "Skip StringIndexerModel.")
+ return dataset
+ }
+
val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
@@ -128,6 +137,11 @@ class StringIndexerModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ if (schema.fieldNames.contains($(inputCol))) {
+ validateAndTransformSchema(schema)
+ } else {
+ // If the input column does not exist during transformation, we skip StringIndexerModel.
+ schema
+ }
}
}
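
A hedged usage sketch of the new skip behavior (assumes a local Spark setup of this era; names are illustrative): a model fitted on a "category" column now passes a dataset without that column through unchanged instead of failing in transformSchema.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.ml.feature.StringIndexer
    import org.apache.spark.sql.SQLContext

    object StringIndexerSkipExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("indexer"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val train = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "a"))).toDF("id", "category")
        val model = new StringIndexer()
          .setInputCol("category")
          .setOutputCol("categoryIndex")
          .fit(train)

        // A prediction-time dataset without the input column is returned unmodified.
        val test = sc.parallelize(Seq(Tuple1(0), Tuple1(1))).toDF("id")
        println(model.transform(test).columns.mkString(", "))  // id

        sc.stop()
      }
    }
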
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 473488dce9b0..ba94d6a3a80a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -69,14 +69,10 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
}
}
- /**
- * Creates a param pair with the given value (for Java).
- */
+ /** Creates a param pair with the given value (for Java). */
def w(value: T): ParamPair[T] = this -> value
- /**
- * Creates a param pair with the given value (for Scala).
- */
+ /** Creates a param pair with the given value (for Scala). */
def ->(value: T): ParamPair[T] = ParamPair(this, value)
override final def toString: String = s"${parent}__$name"
@@ -190,6 +186,7 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double =>
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+ /** Creates a param pair with the given value (for Java). */
override def w(value: Double): ParamPair[Double] = super.w(value)
}
@@ -209,6 +206,7 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+ /** Creates a param pair with the given value (for Java). */
override def w(value: Int): ParamPair[Int] = super.w(value)
}
@@ -228,6 +226,7 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+ /** Creates a param pair with the given value (for Java). */
override def w(value: Float): ParamPair[Float] = super.w(value)
}
@@ -247,6 +246,7 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+ /** Creates a param pair with the given value (for Java). */
override def w(value: Long): ParamPair[Long] = super.w(value)
}
@@ -260,6 +260,7 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+ /** Creates a param pair with the given value (for Java). */
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}
@@ -274,8 +275,6 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)
- override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
-
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
}
@@ -291,10 +290,9 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)
- override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value)
-
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
- def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray)
+ def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
+ w(value.asScala.map(_.asInstanceOf[Double]).toArray)
}
/**
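
The DoubleArrayParam change above exists because Java callers hold boxed java.lang.Double values; a standalone sketch of the unboxing it performs (names are illustrative):

    import scala.collection.JavaConverters._

    object DoubleUnboxingExample {
      def main(args: Array[String]): Unit = {
        // What a Java caller would pass in: a list of boxed doubles.
        val fromJava: java.util.List[java.lang.Double] = java.util.Arrays.asList(0.1, 0.2, 0.3)
        // Unbox element by element to obtain a Scala Array[Double].
        val unboxed: Array[Double] = fromJava.asScala.map(_.asInstanceOf[Double]).toArray
        println(unboxed.sum)
      }
    }
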
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 1ffb5eddc36b..8ffbcf0d8bc7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen {
val params = Seq(
ParamDesc[Double]("regParam", "regularization parameter (>= 0)",
isValid = "ParamValidators.gtEq(0)"),
- ParamDesc[Int]("maxIter", "max number of iterations (>= 0)",
+ ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)",
isValid = "ParamValidators.gtEq(0)"),
ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index ed08417bd4df..a0c8ccdac9ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -45,10 +45,10 @@ private[ml] trait HasRegParam extends Params {
private[ml] trait HasMaxIter extends Params {
/**
- * Param for max number of iterations (>= 0).
+ * Param for maximum number of iterations (>= 0).
* @group param
*/
- final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0))
+ final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0))
/** @group getParam */
final def getMaxIter: Int = $(maxIter)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 69f4f5414c8c..b7e374bb6cb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -198,7 +198,7 @@ private[ml] object GBTRegressionModel {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 7c40db1a4004..70cd8e9e87fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -83,7 +83,7 @@ class LinearRegression(override val uid: String)
setDefault(elasticNetParam -> 0.0)
/**
- * Set the maximal number of iterations.
+ * Set the maximum number of iterations.
* Default is 100.
* @group setParam
*/
@@ -321,7 +321,7 @@ private class LeastSquaresAggregator(
}
(weightsArray, -sum + labelMean / labelStd, weightsArray.length)
}
-
+
private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
private val gradientSumArray = Array.ofDim[Double](dim)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index ae767a17329d..49a1f7ce8c99 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -152,7 +152,7 @@ private[ml] object RandomForestRegressionModel {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
new RandomForestRegressionModel(parent.uid, newTrees)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 6434b64aed15..cb29392e8bc6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
- copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
+ copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
@Experimental
class CrossValidatorModel private[ml] (
override val uid: String,
- val bestModel: Model[_])
+ val bestModel: Model[_],
+ val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams {
override def validateParams(): Unit = {
@@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] (
}
override def copy(extra: ParamMap): CrossValidatorModel = {
- val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
+ val copied = new CrossValidatorModel(
+ uid,
+ bestModel.copy(extra).asInstanceOf[Model[_]],
+ avgMetrics.clone())
copyValues(copied, extra)
}
}
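
With avgMetrics now stored on CrossValidatorModel, callers can inspect the averaged metric for each candidate ParamMap after fitting. A hedged end-to-end sketch (assumes a local master and a small synthetic dataset; names are illustrative):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.sql.SQLContext

    object AvgMetricsExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cv"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // Tiny separable toy data: label 1.0 has a clearly larger first feature than label 0.0.
        val training = sc.parallelize(0 until 12).map { i =>
          val label = (i % 2).toDouble
          (label, Vectors.dense(label + 0.05 * i, 1.0))
        }.toDF("label", "features")

        val lr = new LogisticRegression()
        val grid = new ParamGridBuilder()
          .addGrid(lr.regParam, Array(0.01, 0.1))
          .build()
        val cv = new CrossValidator()
          .setEstimator(lr)
          .setEvaluator(new BinaryClassificationEvaluator())
          .setEstimatorParamMaps(grid)
          .setNumFolds(2)

        val cvModel = cv.fit(training)
        // avgMetrics(i) is the cross-validated metric for grid(i), averaged over the folds.
        cvModel.avgMetrics.zip(grid).foreach { case (m, pm) => println(s"$pm -> $m") }

        sc.stop()
      }
    }
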
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 65f30fdba739..8f66bc808a00 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -43,7 +43,8 @@ import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.stat.test.ChiSqTestResult
-import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
+import org.apache.spark.mllib.stat.{
+ KernelDensity, MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.loss.Losses
@@ -399,7 +400,7 @@ private[python] class PythonMLLibAPI extends Serializable {
val sigma = si.map(_.asInstanceOf[DenseMatrix])
val gaussians = Array.tabulate(weight.length){
i => new MultivariateGaussian(mean(i), sigma(i))
- }
+ }
val model = new GaussianMixtureModel(weight, gaussians)
model.predictSoft(data).map(Vectors.dense)
}
@@ -494,7 +495,7 @@ private[python] class PythonMLLibAPI extends Serializable {
def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = {
new Normalizer(p).transform(rdd)
}
-
+
/**
* Java stub for StandardScaler.fit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
@@ -945,6 +946,15 @@ private[python] class PythonMLLibAPI extends Serializable {
r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any])))
}
+ /**
+ * Java stub for the estimate method of KernelDensity
+ */
+ def estimateKernelDensity(
+ sample: JavaRDD[Double],
+ bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = {
+ return new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
+ points.asScala.toArray)
+ }
}
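
The estimateKernelDensity stub above just forwards to mllib's KernelDensity; the equivalent Scala usage looks roughly like the sketch below (assumes a local master; names are illustrative):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.stat.KernelDensity

    object KernelDensityExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("kde"))
        val sample = sc.parallelize(Seq(1.0, 1.5, 2.0, 2.5, 5.0))
        // Estimate the density at a few query points with the given bandwidth.
        val densities = new KernelDensity()
          .setSample(sample)
          .setBandwidth(0.5)
          .estimate(Array(1.0, 2.0, 5.0))
        println(densities.mkString(", "))
        sc.stop()
      }
    }
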
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index e9a23e40cc79..fc509d2ba147 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.IndexedSeq
import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
@@ -36,11 +37,11 @@ import org.apache.spark.util.Utils
* independent Gaussian distributions with associated "mixing" weights
* specifying each's contribution to the composite.
*
- * Given a set of sample points, this class will maximize the log-likelihood
- * for a mixture of k Gaussians, iterating until the log-likelihood changes by
+ * Given a set of sample points, this class will maximize the log-likelihood
+ * for a mixture of k Gaussians, iterating until the log-likelihood changes by
* less than convergenceTol, or until it has reached the max number of iterations.
* While this process is generally guaranteed to converge, it is not guaranteed
- * to find a global optimum.
+ * to find a global optimum.
*
* Note: For high-dimensional data (with many features), this algorithm may perform poorly.
* This is due to high-dimensional data (a) making it difficult to cluster at all (based
@@ -53,24 +54,24 @@ import org.apache.spark.util.Utils
*/
@Experimental
class GaussianMixture private (
- private var k: Int,
- private var convergenceTol: Double,
+ private var k: Int,
+ private var convergenceTol: Double,
private var maxIterations: Int,
private var seed: Long) extends Serializable {
-
+
/**
* Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01,
* maxIterations: 100, seed: random}.
*/
def this() = this(2, 0.01, 100, Utils.random.nextLong())
-
+
// number of samples per cluster to use when initializing Gaussians
private val nSamples = 5
-
- // an initializing GMM can be provided rather than using the
+
+ // an initializing GMM can be provided rather than using the
// default random starting point
private var initialModel: Option[GaussianMixtureModel] = None
-
+
/** Set the initial GMM starting point, bypassing the random initialization.
* You must call setK() prior to calling this method, and the condition
* (model.k == this.k) must be met; failure will result in an IllegalArgumentException
@@ -83,37 +84,37 @@ class GaussianMixture private (
}
this
}
-
+
/** Return the user supplied initial GMM, if supplied */
def getInitialModel: Option[GaussianMixtureModel] = initialModel
-
+
/** Set the number of Gaussians in the mixture model. Default: 2 */
def setK(k: Int): this.type = {
this.k = k
this
}
-
+
/** Return the number of Gaussians in the mixture model */
def getK: Int = k
-
+
/** Set the maximum number of iterations to run. Default: 100 */
def setMaxIterations(maxIterations: Int): this.type = {
this.maxIterations = maxIterations
this
}
-
+
/** Return the maximum number of iterations to run */
def getMaxIterations: Int = maxIterations
-
+
/**
- * Set the largest change in log-likelihood at which convergence is
+ * Set the largest change in log-likelihood at which convergence is
* considered to have occurred.
*/
def setConvergenceTol(convergenceTol: Double): this.type = {
this.convergenceTol = convergenceTol
this
}
-
+
/**
* Return the largest change in log-likelihood at which convergence is
* considered to have occurred.
@@ -132,41 +133,41 @@ class GaussianMixture private (
/** Perform expectation maximization */
def run(data: RDD[Vector]): GaussianMixtureModel = {
val sc = data.sparkContext
-
+
// we will operate on the data as breeze data
val breezeData = data.map(_.toBreeze).cache()
-
+
// Get length of the input vectors
val d = breezeData.first().length
-
+
// Determine initial weights and corresponding Gaussians.
// If the user supplied an initial GMM, we use those values, otherwise
// we start with uniform weights, a random mean from the data, and
// diagonal covariance matrices using component variances
- // derived from the samples
+ // derived from the samples
val (weights, gaussians) = initialModel match {
case Some(gmm) => (gmm.weights, gmm.gaussians)
-
+
case None => {
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
- (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
+ (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
- new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
+ new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
})
}
}
-
- var llh = Double.MinValue // current log-likelihood
+
+ var llh = Double.MinValue // current log-likelihood
var llhp = 0.0 // previous log-likelihood
-
+
var iter = 0
while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol) {
// create and broadcast curried cluster contribution function
val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)
-
+
// aggregate the cluster contribution for all sample points
val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)
-
+
// Create new distributions based on the partial assignments
// (often referred to as the "M" step in literature)
val sumWeights = sums.weights.sum
@@ -179,22 +180,25 @@ class GaussianMixture private (
gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
i = i + 1
}
-
+
llhp = llh // current becomes previous
llh = sums.logLikelihood // this is the freshly computed log-likelihood
iter += 1
- }
-
+ }
+
new GaussianMixtureModel(weights, gaussians)
}
-
+
+ /** Java-friendly version of [[run()]] */
+ def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)
+
/** Average of dense breeze vectors */
private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
val v = BDV.zeros[Double](x(0).length)
x.foreach(xi => v += xi)
- v / x.length.toDouble
+ v / x.length.toDouble
}
-
+
/**
* Construct matrix where diagonal entries are element-wise
* variance of input vectors (computes biased variance)
@@ -210,14 +214,14 @@ class GaussianMixture private (
// companion class to provide zero constructor for ExpectationSum
private object ExpectationSum {
def zero(k: Int, d: Int): ExpectationSum = {
- new ExpectationSum(0.0, Array.fill(k)(0.0),
+ new ExpectationSum(0.0, Array.fill(k)(0.0),
Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d)))
}
-
+
// compute cluster contributions for each input point
// (U, T) => U for aggregation
def add(
- weights: Array[Double],
+ weights: Array[Double],
dists: Array[MultivariateGaussian])
(sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
val p = weights.zip(dists).map {
@@ -235,7 +239,7 @@ private object ExpectationSum {
i = i + 1
}
sums
- }
+ }
}
// Aggregation class for partial expectation results
@@ -244,9 +248,9 @@ private class ExpectationSum(
val weights: Array[Double],
val means: Array[BDV[Double]],
val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
-
+
val k = weights.length
-
+
def +=(x: ExpectationSum): ExpectationSum = {
var i = 0
while (i < k) {
@@ -257,5 +261,5 @@ private class ExpectationSum(
}
logLikelihood += x.logLikelihood
this
- }
+ }
}
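
The new run(JavaRDD[Vector]) overload simply unwraps to the RDD version; a hedged Scala sketch of the estimator itself (assumes a local master; the data and names are illustrative):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.clustering.GaussianMixture
    import org.apache.spark.mllib.linalg.Vectors

    object GaussianMixtureExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("gmm"))
        // Two well-separated clusters in two dimensions.
        val data = sc.parallelize(Seq(
          Vectors.dense(-5.0, -4.9), Vectors.dense(-5.1, -5.2), Vectors.dense(-4.8, -5.0),
          Vectors.dense(5.0, 5.1), Vectors.dense(4.9, 4.8), Vectors.dense(5.2, 5.0)))
        val model = new GaussianMixture().setK(2).run(data)
        model.weights.zip(model.gaussians).foreach { case (w, g) =>
          println(s"weight=$w mean=${g.mu}")
        }
        sc.stop()
      }
    }
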
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 86353aed8115..cb807c803810 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
@@ -34,10 +35,10 @@ import org.apache.spark.sql.{SQLContext, Row}
/**
* :: Experimental ::
*
- * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
- * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
- * the respective mean and covariance for each Gaussian distribution i=1..k.
- *
+ * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
+ * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
+ * the respective mean and covariance for each Gaussian distribution i=1..k.
+ *
* @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is
* the weight for Gaussian i, and weights.sum == 1
* @param gaussians Array of MultivariateGaussian where gaussians(i) represents
@@ -45,9 +46,9 @@ import org.apache.spark.sql.{SQLContext, Row}
*/
@Experimental
class GaussianMixtureModel(
- val weights: Array[Double],
- val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{
-
+ val weights: Array[Double],
+ val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {
+
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
override protected def formatVersion = "1.0"
@@ -64,20 +65,24 @@ class GaussianMixtureModel(
val responsibilityMatrix = predictSoft(points)
responsibilityMatrix.map(r => r.indexOf(r.max))
}
-
+
+ /** Java-friendly version of [[predict()]] */
+ def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
+ predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
+
/**
* Given the input vectors, return the membership value of each vector
- * to all mixture components.
+ * to all mixture components.
*/
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
val bcDists = sc.broadcast(gaussians)
val bcWeights = sc.broadcast(weights)
- points.map { x =>
+ points.map { x =>
computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
}
}
-
+
/**
* Compute the partial assignments for each vector
*/
@@ -89,7 +94,7 @@ class GaussianMixtureModel(
val p = weights.zip(dists).map {
case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt)
}
- val pSum = p.sum
+ val pSum = p.sum
for (i <- 0 until k) {
p(i) /= pSum
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 6cf26445f20a..974b26924dfb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
import org.apache.spark.rdd.RDD
@@ -345,6 +346,11 @@ class DistributedLDAModel private (
}
}
+ /** Java-friendly version of [[topicDistributions]] */
+ def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = {
+ JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
+ }
+
// TODO:
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
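
javaTopicDistributions follows the usual wrapping recipe for exposing a keyed RDD to Java; a standalone sketch of just that conversion (the helper name is illustrative, not part of the Spark API):

    import scala.reflect.ClassTag

    import org.apache.spark.api.java.JavaPairRDD
    import org.apache.spark.rdd.RDD

    object JavaPairWrapping {
      // Cast the primitive Long key to its boxed form, then wrap as a JavaPairRDD.
      def toJavaPair[V: ClassTag](rdd: RDD[(Long, V)]): JavaPairRDD[java.lang.Long, V] =
        JavaPairRDD.fromRDD(rdd.asInstanceOf[RDD[(java.lang.Long, V)]])
    }
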
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 1ed01c9d8ba0..e7a243f854e3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -121,7 +121,7 @@ class PowerIterationClustering private[clustering] (
import org.apache.spark.mllib.clustering.PowerIterationClustering._
/** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100,
- * initMode: "random"}.
+ * initMode: "random"}.
*/
def this() = this(k = 2, maxIterations = 100, initMode = "random")
@@ -243,7 +243,7 @@ object PowerIterationClustering extends Logging {
/**
* Generates random vertex properties (v0) to start power iteration.
- *
+ *
* @param g a graph representing the normalized affinity matrix (W)
* @return a graph with edges representing W and vertices representing a random vector
* with unit 1-norm
@@ -266,7 +266,7 @@ object PowerIterationClustering extends Logging {
* Generates the degree vector as the vertex properties (v0) to start power iteration.
* It is not exactly the node degrees but just the normalized sum similarities. Call it
* as degree vector because it is used in the PIC paper.
- *
+ *
* @param g a graph representing the normalized affinity matrix (W)
* @return a graph with edges representing W and vertices representing the degree vector
*/
@@ -276,7 +276,7 @@ object PowerIterationClustering extends Logging {
val v0 = g.vertices.mapValues(_ / sum)
GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges)
}
-
+
/**
* Runs power iteration.
* @param g input graph with edges representing the normalized affinity matrix (W) and vertices
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 812014a04171..d9b34cec6489 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -21,8 +21,10 @@ import scala.reflect.ClassTag
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaSparkContext._
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream}
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -178,7 +180,7 @@ class StreamingKMeans(
/** Set the decay factor directly (for forgetful algorithms). */
def setDecayFactor(a: Double): this.type = {
- this.decayFactor = decayFactor
+ this.decayFactor = a
this
}
@@ -234,6 +236,9 @@ class StreamingKMeans(
}
}
+ /** Java-friendly version of `trainOn`. */
+ def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream)
+
/**
* Use the clustering model to make predictions on batches of data from a DStream.
*
@@ -245,6 +250,11 @@ class StreamingKMeans(
data.map(model.predict)
}
+ /** Java-friendly version of `predictOn`. */
+ def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = {
+ JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]])
+ }
+
/**
* Use the model to make predictions on the values of a DStream and carry over its keys.
*
@@ -257,6 +267,14 @@ class StreamingKMeans(
data.mapValues(model.predict)
}
+ /** Java-friendly version of `predictOnValues`. */
+ def predictOnValues[K](
+ data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = {
+ implicit val tag = fakeClassTag[K]
+ JavaPairDStream.fromPairDStream(
+ predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]])
+ }
+
/** Check whether cluster centers have been initialized. */
private[this] def assertInitialized(): Unit = {
if (model.clusterCenters == null) {
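
Besides the Java-friendly wrappers, this hunk fixes a real bug: setDecayFactor previously assigned the field to itself, silently ignoring the argument. A minimal non-Spark sketch of that failure mode (names are illustrative):

    class Decaying {
      private var decayFactor: Double = 1.0

      // Buggy form: the right-hand side resolves to the field, so this is a self-assignment.
      def setDecayFactorBuggy(a: Double): this.type = { this.decayFactor = decayFactor; this }

      // Fixed form: assign the parameter, as the patch now does.
      def setDecayFactor(a: Double): this.type = { this.decayFactor = a; this }

      def getDecayFactor: Double = decayFactor
    }

    object DecaySetterExample {
      def main(args: Array[String]): Unit = {
        val d = new Decaying
        d.setDecayFactorBuggy(0.5)
        println(d.getDecayFactor)  // still 1.0: the argument was dropped
        d.setDecayFactor(0.5)
        println(d.getDecayFactor)  // 0.5
      }
    }
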
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 9cc2d0ffcab7..5f8c1dea237b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -108,7 +108,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf
* (ordered by statistic value descending)
*/
@Experimental
-class ChiSqSelector (val numTopFeatures: Int) {
+class ChiSqSelector (val numTopFeatures: Int) extends Serializable {
/**
* Returns a ChiSquared feature selector.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
index b0985baf9b27..d67fe6c3ee4f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
@@ -25,10 +25,10 @@ import org.apache.spark.mllib.linalg._
* Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a
* provided "weight" vector. In other words, it scales each column of the dataset by a scalar
* multiplier.
- * @param scalingVector The values used to scale the reference vector's individual components.
+ * @param scalingVec The values used to scale the reference vector's individual components.
*/
@Experimental
-class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
+class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer {
/**
* Does the hadamard product transformation.
@@ -37,15 +37,15 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
* @return transformed vector.
*/
override def transform(vector: Vector): Vector = {
- require(vector.size == scalingVector.size,
- s"vector sizes do not match: Expected ${scalingVector.size} but found ${vector.size}")
+ require(vector.size == scalingVec.size,
+ s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}")
vector match {
case dv: DenseVector =>
val values: Array[Double] = dv.values.clone()
- val dim = scalingVector.size
+ val dim = scalingVec.size
var i = 0
while (i < dim) {
- values(i) *= scalingVector(i)
+ values(i) *= scalingVec(i)
i += 1
}
Vectors.dense(values)
@@ -54,7 +54,7 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
val dim = values.length
var i = 0
while (i < dim) {
- values(i) *= scalingVector(indices(i))
+ values(i) *= scalingVec(indices(i))
i += 1
}
Vectors.sparse(size, indices, values)
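
The rename above is API-visible (the constructor parameter is a val), so the new name scalingVec shows up in user code. A small local sketch, no SparkContext required (the object name is illustrative):

    import org.apache.spark.mllib.feature.ElementwiseProduct
    import org.apache.spark.mllib.linalg.Vectors

    object ElementwiseProductExample {
      def main(args: Array[String]): Unit = {
        // The weight vector is now exposed as `scalingVec` (was `scalingVector`).
        val transformer = new ElementwiseProduct(Vectors.dense(0.0, 1.0, 2.0))
        println(transformer.scalingVec)                                // [0.0,1.0,2.0]
        println(transformer.transform(Vectors.dense(3.0, 4.0, 5.0)))  // [0.0,4.0,10.0]
      }
    }
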
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 466ae95859b8..51546d41c36a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -42,7 +42,7 @@ import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.sql.{SQLContext, Row}
/**
- * Entry in vocabulary
+ * Entry in vocabulary
*/
private case class VocabWord(
var word: String,
@@ -56,18 +56,18 @@ private case class VocabWord(
* :: Experimental ::
* Word2Vec creates vector representation of words in a text corpus.
* The algorithm first constructs a vocabulary from the corpus
- * and then learns vector representation of words in the vocabulary.
- * The vector representation can be used as features in
+ * and then learns vector representation of words in the vocabulary.
+ * The vector representation can be used as features in
* natural language processing and machine learning algorithms.
- *
- * We used skip-gram model in our implementation and hierarchical softmax
+ *
+ * We used skip-gram model in our implementation and hierarchical softmax
* method to train the model. The variable names in the implementation
* matches the original C implementation.
*
- * For original C implementation, see https://code.google.com/p/word2vec/
- * For research papers, see
+ * For original C implementation, see https://code.google.com/p/word2vec/
+ * For research papers, see
* Efficient Estimation of Word Representations in Vector Space
- * and
+ * and
* Distributed Representations of Words and Phrases and their Compositionality.
*/
@Experimental
@@ -79,7 +79,7 @@ class Word2Vec extends Serializable with Logging {
private var numIterations = 1
private var seed = Utils.random.nextLong()
private var minCount = 5
-
+
/**
* Sets vector size (default: 100).
*/
@@ -122,15 +122,15 @@ class Word2Vec extends Serializable with Logging {
this
}
- /**
- * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
+ /**
+ * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
* model's vocabulary (default: 5).
*/
def setMinCount(minCount: Int): this.type = {
this.minCount = minCount
this
}
-
+
private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
@@ -150,13 +150,13 @@ class Word2Vec extends Serializable with Logging {
.map(x => VocabWord(
x._1,
x._2,
- new Array[Int](MAX_CODE_LENGTH),
- new Array[Int](MAX_CODE_LENGTH),
+ new Array[Int](MAX_CODE_LENGTH),
+ new Array[Int](MAX_CODE_LENGTH),
0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)
-
+
vocabSize = vocab.length
require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
"the setting of minCount, which could be large enough to remove all your words in sentences.")
@@ -198,8 +198,8 @@ class Word2Vec extends Serializable with Logging {
}
var pos1 = vocabSize - 1
var pos2 = vocabSize
-
- var min1i = 0
+
+ var min1i = 0
var min2i = 0
a = 0
@@ -268,15 +268,15 @@ class Word2Vec extends Serializable with Logging {
val words = dataset.flatMap(x => x)
learnVocab(words)
-
+
createBinaryTree()
-
+
val sc = dataset.context
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)
-
+
val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
new Iterator[Array[Int]] {
def hasNext: Boolean = iter.hasNext
@@ -297,7 +297,7 @@ class Word2Vec extends Serializable with Logging {
}
}
}
-
+
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
@@ -402,7 +402,7 @@ class Word2Vec extends Serializable with Logging {
}
}
newSentences.unpersist()
-
+
val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
var i = 0
while (i < vocabSize) {
@@ -480,7 +480,7 @@ class Word2VecModel private[mllib] (
/**
* Transforms a word to its vector representation
- * @param word a word
+ * @param word a word
* @return vector representation of word
*/
def transform(word: String): Vector = {
@@ -495,7 +495,7 @@ class Word2VecModel private[mllib] (
/**
* Find synonyms of a word
* @param word a word
- * @param num number of synonyms to find
+ * @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
@@ -506,7 +506,7 @@ class Word2VecModel private[mllib] (
/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
- * @param num number of synonyms to find
+ * @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index ec38529cf8fa..3523f1804325 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -213,9 +213,9 @@ private[spark] object BLAS extends Serializable with Logging {
def scal(a: Double, x: Vector): Unit = {
x match {
case sx: SparseVector =>
- f2jBLAS.dscal(sx.values.size, a, sx.values, 1)
+ f2jBLAS.dscal(sx.values.length, a, sx.values, 1)
case dx: DenseVector =>
- f2jBLAS.dscal(dx.values.size, a, dx.values, 1)
+ f2jBLAS.dscal(dx.values.length, a, dx.values, 1)
case _ =>
throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.")
}
@@ -228,7 +228,7 @@ private[spark] object BLAS extends Serializable with Logging {
}
_nativeBLAS
}
-
+
/**
* A := alpha * x * x^T^ + A
* @param alpha a real scalar that will be multiplied to x * x^T^.
@@ -264,7 +264,7 @@ private[spark] object BLAS extends Serializable with Logging {
j += 1
}
i += 1
- }
+ }
}
private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) {
@@ -505,7 +505,7 @@ private[spark] object BLAS extends Serializable with Logging {
nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta,
y.values, 1)
}
-
+
/**
* y := alpha * A * x + beta * y
* For `DenseMatrix` A and `SparseVector` x.
@@ -557,7 +557,7 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
}
-
+
/**
* y := alpha * A * x + beta * y
* For `SparseMatrix` A and `SparseVector` x.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
index 866936aa4f11..ae3ba3099c87 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
@@ -81,7 +81,7 @@ private[mllib] object EigenValueDecomposition {
require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE,
s"k = $k and/or n = $n are too large to compute an eigendecomposition")
-
+
var ido = new intW(0)
var info = new intW(0)
var resid = new Array[Double](n)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 9584da8e3a0f..85e63b1382b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -197,6 +197,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
override def typeName: String = "matrix"
+ override def pyUDT: String = "pyspark.mllib.linalg.MatrixUDT"
+
private[spark] override def asNullable: MatrixUDT = this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 4b7d0589c973..06e45e10c5bf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -179,7 +179,7 @@ object GradientDescent extends Logging {
* if it's L2 updater; for L1 updater, the same logic is followed.
*/
var regVal = updater.compute(
- weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
+ weights, Vectors.zeros(weights.size), 0, 1, regParam)._2
for (i <- 1 to numIterations) {
val bcWeights = data.context.broadcast(weights)
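
Vectors.zeros(n) is just a clearer way to build the all-zero dense vector than allocating an Array[Double] by hand; a quick check (the object name is illustrative):

    import org.apache.spark.mllib.linalg.Vectors

    object ZerosExample {
      def main(args: Array[String]): Unit = {
        val a = Vectors.zeros(3)
        val b = Vectors.dense(new Array[Double](3))
        // Both are dense vectors of three zeros.
        println(a.toArray.sameElements(b.toArray))  // true
      }
    }
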
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
index 34b447584e52..622b53a252ac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
@@ -27,10 +27,10 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel
* PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
*/
private[mllib] class BinaryClassificationPMMLModelExport(
- model : GeneralizedLinearModel,
+ model : GeneralizedLinearModel,
description : String,
normalizationMethod : RegressionNormalizationMethodType,
- threshold: Double)
+ threshold: Double)
extends PMMLModelExport {
populateBinaryClassificationPMML()
@@ -72,7 +72,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
.withUsageType(FieldUsageType.ACTIVE))
regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}
-
+
// add target field
val targetField = FieldName.create("target")
dataDictionary
@@ -80,9 +80,9 @@ private[mllib] class BinaryClassificationPMMLModelExport(
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))
-
+
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
-
+
pmml.setDataDictionary(dataDictionary)
pmml.withModels(regressionModel)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
index ebdeae50bb32..c5fdecd3ca17 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
@@ -25,7 +25,7 @@ import scala.beans.BeanProperty
import org.dmg.pmml.{Application, Header, PMML, Timestamp}
private[mllib] trait PMMLModelExport {
-
+
/**
* Holder of the exported model in PMML format
*/
@@ -33,7 +33,7 @@ private[mllib] trait PMMLModelExport {
val pmml: PMML = new PMML
setHeader(pmml)
-
+
private def setHeader(pmml: PMML): Unit = {
val version = getClass.getPackage.getImplementationVersion
val app = new Application().withName("Apache Spark MLlib").withVersion(version)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
index c16e83d6a067..29bd689e1185 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
@@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.regression.RidgeRegressionModel
private[mllib] object PMMLModelExportFactory {
-
+
/**
- * Factory object to help creating the necessary PMMLModelExport implementation
+ * Factory object to help creating the necessary PMMLModelExport implementation
* taking as input the machine learning model (for example KMeansModel).
*/
def createPMMLModelExport(model: Any): PMMLModelExport = {
@@ -44,7 +44,7 @@ private[mllib] object PMMLModelExportFactory {
new GeneralizedLinearPMMLModelExport(lasso, "lasso regression")
case svm: SVMModel =>
new BinaryClassificationPMMLModelExport(
- svm, "linear SVM", RegressionNormalizationMethodType.NONE,
+ svm, "linear SVM", RegressionNormalizationMethodType.NONE,
svm.getThreshold.getOrElse(0.0))
case logistic: LogisticRegressionModel =>
if (logistic.numClasses == 2) {
@@ -60,5 +60,5 @@ private[mllib] object PMMLModelExportFactory {
"PMML Export not supported for model: " + model.getClass.getName)
}
}
-
+
}
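
The factory above is what routes each supported MLlib model type to its PMML exporter. As a rough usage sketch (assuming a SparkContext `sc` is in scope and that the trained model mixes in PMMLExportable, which ultimately reaches createPMMLModelExport):

    // Sketch only: train a small KMeans model and export it as a PMML string.
    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    val data = sc.parallelize(Seq(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0)))
    val model = KMeans.train(data, 2, 10)   // k = 2, maxIterations = 10
    // toPMML() delegates to PMMLModelExportFactory.createPMMLModelExport(this)
    println(model.toPMML())
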
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
index 7db5a14fd45a..174d5e0f6c9f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
@@ -234,7 +234,7 @@ object RandomRDDs {
*
* @param sc SparkContext used to create the RDD.
* @param shape shape parameter (> 0) for the gamma distribution
- * @param scale scale parameter (> 0) for the gamma distribution
+ * @param scale scale parameter (> 0) for the gamma distribution
* @param size Size of the RDD.
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
* @param seed Random seed (default: a random long integer).
@@ -293,7 +293,7 @@ object RandomRDDs {
*
* @param sc SparkContext used to create the RDD.
* @param mean mean for the log normal distribution
- * @param std standard deviation for the log normal distribution
+ * @param std standard deviation for the log normal distribution
* @param size Size of the RDD.
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
* @param seed Random seed (default: a random long integer).
@@ -671,7 +671,7 @@ object RandomRDDs {
*
* @param sc SparkContext used to create the RDD.
* @param shape shape parameter (> 0) for the gamma distribution.
- * @param scale scale parameter (> 0) for the gamma distribution.
+ * @param scale scale parameter (> 0) for the gamma distribution.
* @param numRows Number of Vectors in the RDD.
* @param numCols Number of elements in each Vector.
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`)
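
The doc fixes above only realign the @param descriptions; the generators themselves are unchanged. A short usage sketch of the documented parameters (assuming a SparkContext `sc`):

    import org.apache.spark.mllib.random.RandomRDDs

    // 1000 draws from Gamma(shape = 2.0, scale = 3.0), spread over 4 partitions.
    val gamma = RandomRDDs.gammaRDD(sc, 2.0, 3.0, 1000L, 4)
    // 500 draws from a log-normal with log-space mean 0.0 and std 1.0.
    val logNormal = RandomRDDs.logNormalRDD(sc, 0.0, 1.0, 500L)
    println(gamma.mean())
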
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index dddefe1944e9..93290e650852 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -175,7 +175,7 @@ class ALS private (
/**
* :: DeveloperApi ::
* Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default
- * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g.
+ * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g.
* `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement,
* at the cost of speed.
*/
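
The comment being fixed documents setFinalRDDStorageLevel. A minimal sketch of the serialized-storage setup it describes (assumes `ratings` is an RDD[Rating] and that spark.rdd.compress is enabled on the SparkConf):

    import org.apache.spark.mllib.recommendation.ALS
    import org.apache.spark.storage.StorageLevel

    val model = new ALS()
      .setRank(10)
      .setIterations(10)
      // serialized storage trades CPU for space, as the comment above notes
      .setFinalRDDStorageLevel(StorageLevel.MEMORY_AND_DISK_SER)
      .run(ratings)
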
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 26be30ff9d6f..6709bd79bc82 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -195,11 +195,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
val initialWeights = {
if (numOfLinearPredictor == 1) {
- Vectors.dense(new Array[Double](numFeatures))
+ Vectors.zeros(numFeatures)
} else if (addIntercept) {
- Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor))
+ Vectors.zeros((numFeatures + 1) * numOfLinearPredictor)
} else {
- Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor))
+ Vectors.zeros(numFeatures * numOfLinearPredictor)
}
}
run(input, initialWeights)
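
The switch to Vectors.zeros above is behavior-preserving: it produces the same all-zero dense vector the old Array-based construction did. A minimal check (sketch):

    import org.apache.spark.mllib.linalg.Vectors

    val viaArray = Vectors.dense(new Array[Double](3)) // [0.0, 0.0, 0.0]
    val viaZeros = Vectors.zeros(3)                    // same values, clearer intent
    assert(viaArray.toArray.sameElements(viaZeros.toArray))
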
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index 96e50faca2b1..f3b46c75c05f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -170,15 +170,15 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
case class Data(boundary: Double, prediction: Double)
def save(
- sc: SparkContext,
- path: String,
- boundaries: Array[Double],
- predictions: Array[Double],
+ sc: SparkContext,
+ path: String,
+ boundaries: Array[Double],
+ predictions: Array[Double],
isotonic: Boolean): Unit = {
val sqlContext = new SQLContext(sc)
val metadata = compact(render(
- ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("isotonic" -> isotonic)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index e0c03d8180c7..7d28ffad45c9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -73,7 +73,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
/**
* Train a regression model with L2-regularization using Stochastic Gradient Descent.
- * This solves the l1-regularized least squares regression formulation
+ * This solves the l2-regularized least squares regression formulation
* f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^
* Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
* its corresponding right hand side label y.
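
The comment fix only corrects "l1" to "l2"; the objective itself is unchanged. For concreteness, a self-contained sketch of that objective on plain arrays (the names `objective` and `regParam` are illustrative, not part of the MLlib API):

    // f(w) = 1/(2n) * ||A w - y||^2 + regParam/2 * ||w||^2
    def objective(a: Array[Array[Double]], y: Array[Double],
                  weights: Array[Double], regParam: Double): Double = {
      val n = a.length
      val residuals = a.zip(y).map { case (row, label) =>
        row.zip(weights).map { case (x, w) => x * w }.sum - label
      }
      val squaredError = residuals.map(r => r * r).sum / (2.0 * n)
      val l2Penalty = weights.map(w => w * w).sum   // ||weights||^2
      squaredError + regParam / 2.0 * l2Penalty
    }
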
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index cea8f3f47307..141052ba813e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -83,21 +83,15 @@ abstract class StreamingLinearAlgorithm[
throw new IllegalArgumentException("Model must be initialized before starting training.")
}
data.foreachRDD { (rdd, time) =>
- val initialWeights =
- model match {
- case Some(m) =>
- m.weights
- case None =>
- val numFeatures = rdd.first().features.size
- Vectors.dense(numFeatures)
+ if (!rdd.isEmpty) {
+ model = Some(algorithm.run(rdd, model.get.weights))
+ logInfo(s"Model updated at time ${time.toString}")
+ val display = model.get.weights.size match {
+ case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
+ case _ => model.get.weights.toArray.mkString("[", ",", "]")
}
- model = Some(algorithm.run(rdd, initialWeights))
- logInfo("Model updated at time %s".format(time.toString))
- val display = model.get.weights.size match {
- case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
- case _ => model.get.weights.toArray.mkString("[", ",", "]")
+ logInfo(s"Current model: weights, ${display}")
}
- logInfo("Current model: weights, %s".format (display))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index a49153bf73c0..235e043c7754 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] (
this
}
- /** Set the initial weights. Default: [0.0, 0.0]. */
+ /** Set the initial weights. */
def setInitialWeights(initialWeights: Vector): this.type = {
this.model = Some(algorithm.createModel(initialWeights, 0.0))
this
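
Taken together with the StreamingLinearAlgorithm change above, the streaming workflow now requires an explicit initial model and silently skips empty micro-batches. A hedged end-to-end sketch (assumes a StreamingContext `ssc` and DStreams `trainingStream`/`testStream` of LabeledPoint are already defined):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD

    val model = new StreamingLinearRegressionWithSGD()
      .setInitialWeights(Vectors.zeros(2)) // must be set before trainOn
      .setStepSize(0.1)
    model.trainOn(trainingStream)          // empty micro-batches no longer update the model
    model.predictOnValues(testStream.map(lp => (lp.label, lp.features))).print()
    ssc.start()
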
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
index b3fad0c52d65..900007ec6bc7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.stat
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -80,6 +81,10 @@ object Statistics {
*/
def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)
+ /** Java-friendly version of [[corr()]] */
+ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double =
+ corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]])
+
/**
* Compute the correlation for the input RDDs using the specified method.
* Methods currently supported: `pearson` (default), `spearman`.
@@ -96,6 +101,10 @@ object Statistics {
*/
def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)
+ /** Java-friendly version of [[corr()]] */
+ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double =
+ corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method)
+
/**
* Conduct Pearson's chi-squared goodness of fit test of the observed data against the
* expected distribution.
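
The new Java-friendly corr overloads added above simply unwrap the JavaRDDs and delegate to the existing RDD[Double] versions. For reference, the Scala-side behavior they forward to (assuming a SparkContext `sc`):

    import org.apache.spark.mllib.stat.Statistics

    val x = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
    val y = sc.parallelize(Seq(1.1, 2.2, 3.1, 4.3))
    val pearson = Statistics.corr(x, y)               // default method is "pearson"
    val spearman = Statistics.corr(x, y, "spearman")
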
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index cd6add9d60b0..cf51b24ff777 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -29,102 +29,102 @@ import org.apache.spark.mllib.util.MLUtils
* the event that the covariance matrix is singular, the density will be computed in a
* reduced dimensional subspace under which the distribution is supported.
* (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]])
- *
+ *
* @param mu The mean vector of the distribution
* @param sigma The covariance matrix of the distribution
*/
@DeveloperApi
class MultivariateGaussian (
- val mu: Vector,
+ val mu: Vector,
val sigma: Matrix) extends Serializable {
require(sigma.numCols == sigma.numRows, "Covariance matrix must be square")
require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size")
-
+
private val breezeMu = mu.toBreeze.toDenseVector
-
+
/**
* private[mllib] constructor
- *
+ *
* @param mu The mean vector of the distribution
* @param sigma The covariance matrix of the distribution
*/
private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = {
this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma))
}
-
+
/**
* Compute distribution dependent constants:
* rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t
- * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+ * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
*/
private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
-
+
/** Returns density of this multivariate Gaussian at given point, x */
def pdf(x: Vector): Double = {
pdf(x.toBreeze)
}
-
+
/** Returns the log-density of this multivariate Gaussian at given point, x */
def logpdf(x: Vector): Double = {
logpdf(x.toBreeze)
}
-
+
/** Returns density of this multivariate Gaussian at given point, x */
private[mllib] def pdf(x: BV[Double]): Double = {
math.exp(logpdf(x))
}
-
+
/** Returns the log-density of this multivariate Gaussian at given point, x */
private[mllib] def logpdf(x: BV[Double]): Double = {
val delta = x - breezeMu
val v = rootSigmaInv * delta
u + v.t * v * -0.5
}
-
+
/**
* Calculate distribution dependent components used for the density function:
* pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))
* where k is length of the mean vector.
- *
- * We here compute distribution-fixed parts
+ *
+ * We here compute distribution-fixed parts
* log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
* and
* D^(-1/2)^ * U, where sigma = U * D * U.t
- *
+ *
* Both the determinant and the inverse can be computed from the singular value decomposition
* of sigma. Noting that covariance matrices are always symmetric and positive semi-definite,
* we can use the eigendecomposition. We also do not compute the inverse directly; noting
- * that
- *
+ * that
+ *
* sigma = U * D * U.t
- * inv(Sigma) = U * inv(D) * U.t
+ * inv(Sigma) = U * inv(D) * U.t
* = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U)
- *
+ *
* and thus
- *
+ *
* -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^
- *
- * To guard against singular covariance matrices, this method computes both the
+ *
+ * To guard against singular covariance matrices, this method computes both the
* pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered
* to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and
* relation to the maximum singular value (same tolerance used by, e.g., Octave).
*/
private def calculateCovarianceConstants: (DBM[Double], Double) = {
val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t
-
+
// For numerical stability, values are considered to be non-zero only if they exceed tol.
// This prevents any inverted value from exceeding (eps * n * max(d))^-1
val tol = MLUtils.EPSILON * max(d) * d.length
-
+
try {
// log(pseudo-determinant) is sum of the logs of all non-zero singular values
val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum
-
- // calculate the root-pseudo-inverse of the diagonal matrix of singular values
+
+ // calculate the root-pseudo-inverse of the diagonal matrix of singular values
// by inverting the square root of all non-zero values
val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
-
+
(pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
} catch {
case uex: UnsupportedOperationException =>
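
The (whitespace-only) edits above sit on the doc comment that derives the pseudo-determinant and root pseudo-inverse from the eigendecomposition of sigma. A standalone Breeze sketch of that derivation, with the tolerance handling simplified and `eps` standing in for the machine-precision constant:

    import breeze.linalg.{diag, eigSym, max, DenseMatrix => BDM, DenseVector => BDV}

    val sigma = BDM((2.0, 0.0), (0.0, 3.0))
    val eigSym.EigSym(d, u) = eigSym(sigma)          // sigma = u * diag(d) * u.t
    val eps = 2.2e-16                                // assumed machine precision
    val tol = eps * max(d) * d.length
    val logPseudoDet = d.toArray.filter(_ > tol).map(math.log).sum
    val pinvS = diag(BDV(d.toArray.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0)))
    val rootSigmaInv = pinvS * u                     // D^(-1/2) * U
    val logNormalizer = -0.5 * (d.length * math.log(2.0 * math.Pi) + logPseudoDet)
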
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index e3ddc7053693..a835f96d5d0e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -270,7 +270,7 @@ object GradientBoostedTrees extends Logging {
logInfo(s"$timer")
if (persistedInput) input.unpersist()
-
+
if (validate) {
new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 99d0e3cf2fd6..069959976a18 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -474,7 +474,7 @@ object RandomForest extends Serializable with Logging {
val (treeIndex, node) = nodeQueue.head
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
- Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
} else {
None
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index ee710fc1ed29..a6d1398fc267 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -83,7 +83,7 @@ class Node (
def predict(features: Vector) : Double = {
if (isLeaf) {
predict.predict
- } else{
+ } else {
if (split.get.featureType == Continuous) {
if (features(split.get.feature) <= split.get.threshold) {
leftNode.get.predict(features)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 681f4c618d30..7c5cfa7bd84c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -82,6 +82,18 @@ object MLUtils {
val value = indexAndValue(1).toDouble
(index, value)
}.unzip
+
+ // check if indices are one-based and in ascending order
+ var previous = -1
+ var i = 0
+ val indicesLength = indices.length
+ while (i < indicesLength) {
+ val current = indices(i)
+ require(current > previous, "indices should be one-based and in ascending order" )
+ previous = current
+ i += 1
+ }
+
(label, indices.toArray, values.toArray)
}
@@ -258,14 +270,30 @@ object MLUtils {
* Returns a new vector with `1.0` (bias) appended to the input vector.
*/
def appendBias(vector: Vector): Vector = {
- val vector1 = vector.toBreeze match {
- case dv: BDV[Double] => BDV.vertcat(dv, new BDV[Double](Array(1.0)))
- case sv: BSV[Double] => BSV.vertcat(sv, new BSV[Double](Array(0), Array(1.0), 1))
- case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+ vector match {
+ case dv: DenseVector =>
+ val inputValues = dv.values
+ val inputLength = inputValues.length
+ val outputValues = Array.ofDim[Double](inputLength + 1)
+ System.arraycopy(inputValues, 0, outputValues, 0, inputLength)
+ outputValues(inputLength) = 1.0
+ Vectors.dense(outputValues)
+ case sv: SparseVector =>
+ val inputValues = sv.values
+ val inputIndices = sv.indices
+ val inputValuesLength = inputValues.length
+ val dim = sv.size
+ val outputValues = Array.ofDim[Double](inputValuesLength + 1)
+ val outputIndices = Array.ofDim[Int](inputValuesLength + 1)
+ System.arraycopy(inputValues, 0, outputValues, 0, inputValuesLength)
+ System.arraycopy(inputIndices, 0, outputIndices, 0, inputValuesLength)
+ outputValues(inputValuesLength) = 1.0
+ outputIndices(inputValuesLength) = dim
+ Vectors.sparse(dim + 1, outputIndices, outputValues)
+ case _ => throw new IllegalArgumentException(s"Do not support vector type ${vector.getClass}")
}
- Vectors.fromBreeze(vector1)
}
-
+
/**
* Returns the squared Euclidean distance between two vectors. The following formula will be used
* if it does not introduce too much numerical error:
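
The appendBias rewrite in the hunk above keeps its contract while dropping the Breeze round trip: a dense vector gains a trailing 1.0, and a sparse vector gains one extra dimension holding the bias. A small usage sketch:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.util.MLUtils

    MLUtils.appendBias(Vectors.dense(0.5, -1.0))
    // -> [0.5, -1.0, 1.0]
    MLUtils.appendBias(Vectors.sparse(3, Array(1), Array(2.0)))
    // -> (4, [1, 3], [2.0, 1.0])  (size grows by one; the bias sits at the new last index)
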
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index da2218056307..599e9cfd23ad 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -55,9 +55,9 @@ public void tearDown() {
@Test
public void hashingTF() {
JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
- RowFactory.create(0, "Hi I heard about Spark"),
- RowFactory.create(0, "I wish Java could use case classes"),
- RowFactory.create(1, "Logistic regression models are neat")
+ RowFactory.create(0.0, "Hi I heard about Spark"),
+ RowFactory.create(0.0, "I wish Java could use case classes"),
+ RowFactory.create(1.0, "Logistic regression models are neat")
));
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
new file mode 100644
index 000000000000..35b18c5308f6
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.util.Arrays;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+public class JavaStringIndexerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext sqlContext;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
+ sqlContext = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ sqlContext = null;
+ }
+
+ @Test
+ public void testStringIndexer() {
+ StructType schema = createStructType(new StructField[] {
+ createStructField("id", IntegerType, false),
+ createStructField("label", StringType, false)
+ });
+ JavaRDD<Row> rdd = jsc.parallelize(
+ Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")));
+ DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+
+ StringIndexer indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex");
+ DataFrame output = indexer.fit(dataset).transform(dataset);
+
+ Assert.assertArrayEquals(
+ new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) },
+ output.orderBy("id").select("id", "labelIndex").collect());
+ }
+
+ /** An alias for RowFactory.create. */
+ private Row c(Object... values) {
+ return RowFactory.create(values);
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
new file mode 100644
index 000000000000..b7c564caad3b
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.util.Arrays;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+public class JavaVectorAssemblerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext sqlContext;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite");
+ sqlContext = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void testVectorAssembler() {
+ StructType schema = createStructType(new StructField[] {
+ createStructField("id", IntegerType, false),
+ createStructField("x", DoubleType, false),
+ createStructField("y", new VectorUDT(), false),
+ createStructField("name", StringType, false),
+ createStructField("z", new VectorUDT(), false),
+ createStructField("n", LongType, false)
+ });
+ Row row = RowFactory.create(
+ 0, 0.0, Vectors.dense(1.0, 2.0), "a",
+ Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
+ JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
+ DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+ VectorAssembler assembler = new VectorAssembler()
+ .setInputCols(new String[] {"x", "y", "z", "n"})
+ .setOutputCol("features");
+ DataFrame output = assembler.transform(dataset);
+ Assert.assertEquals(
+ Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
+ output.select("features").first().getAs(0));
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
index e7df10dfa63a..9890155e9f86 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -50,6 +50,7 @@ public void testParams() {
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
Assert.assertEquals(testParams.getMyStringParam(), "a");
+ Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0);
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 947ae3a2ce06..ff5929235ac2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -51,7 +51,8 @@ public String uid() {
public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
public JavaTestParams setMyIntParam(int value) {
- set(myIntParam_, value); return this;
+ set(myIntParam_, value);
+ return this;
}
private DoubleParam myDoubleParam_;
@@ -60,7 +61,8 @@ public JavaTestParams setMyIntParam(int value) {
public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
public JavaTestParams setMyDoubleParam(double value) {
- set(myDoubleParam_, value); return this;
+ set(myDoubleParam_, value);
+ return this;
}
private Param<String> myStringParam_;
@@ -69,7 +71,18 @@ public JavaTestParams setMyDoubleParam(double value) {
public String getMyStringParam() { return getOrDefault(myStringParam_); }
public JavaTestParams setMyStringParam(String value) {
- set(myStringParam_, value); return this;
+ set(myStringParam_, value);
+ return this;
+ }
+
+ private DoubleArrayParam myDoubleArrayParam_;
+ public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
+
+ public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
+
+ public JavaTestParams setMyDoubleArrayParam(double[] value) {
+ set(myDoubleArrayParam_, value);
+ return this;
}
private void init() {
@@ -79,8 +92,14 @@ private void init() {
List<String> validStrings = Lists.newArrayList("a", "b");
myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
ParamValidators.inArray(validStrings));
- setDefault(myIntParam_, 1);
- setDefault(myDoubleParam_, 0.5);
+ myDoubleArrayParam_ =
+ new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
+
+ setDefault(myIntParam(), 1);
+ setDefault(myIntParam().w(1));
+ setDefault(myDoubleParam(), 0.5);
setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
+ setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
+ setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
index 67c262d0f9d8..928301523fba 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
+++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.util
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class IdentifiableSuite extends FunSuite {
+class IdentifiableSuite extends SparkFunSuite {
import IdentifiableSuite.Test
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
similarity index 95%
rename from mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
rename to mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
index 640d2ec55e4e..55787f8606d4 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.ml.classification;
+package org.apache.spark.mllib.classification;
import java.io.Serializable;
import java.util.List;
@@ -28,7 +28,6 @@
import org.junit.Test;
import org.apache.spark.SparkConf;
-import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
new file mode 100644
index 000000000000..467a7a69e8f3
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+public class JavaGaussianMixtureSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaGaussianMixture");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runGaussianMixture() {
+ List<Vector> points = Lists.newArrayList(
+ Vectors.dense(1.0, 2.0, 6.0),
+ Vectors.dense(1.0, 3.0, 0.0),
+ Vectors.dense(1.0, 4.0, 6.0)
+ );
+
+ JavaRDD<Vector> data = sc.parallelize(points, 2);
+ GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
+ .run(data);
+ assertEquals(model.gaussians().length, 2);
+ JavaRDD<Integer> predictions = model.predict(data);
+ predictions.first();
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index 96c2da169961..581c033f08eb 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -107,6 +107,10 @@ public void distributedLDAModel() {
// Check: log probabilities
assert(model.logLikelihood() < 0.0);
assert(model.logPrior() < 0.0);
+
+ // Check: topic distributions
+ JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
+ assertEquals(topicDistributions.count(), corpus.count());
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
new file mode 100644
index 000000000000..3b0e879eec77
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+import java.util.List;
+
+import scala.Tuple2;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.apache.spark.streaming.JavaTestUtils.*;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+
+public class JavaStreamingKMeansSuite implements Serializable {
+
+ protected transient JavaStreamingContext ssc;
+
+ @Before
+ public void setUp() {
+ SparkConf conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("test")
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
+ ssc = new JavaStreamingContext(conf, new Duration(1000));
+ ssc.checkpoint("checkpoint");
+ }
+
+ @After
+ public void tearDown() {
+ ssc.stop();
+ ssc = null;
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void javaAPI() {
+ List<Vector> trainingBatch = Lists.newArrayList(
+ Vectors.dense(1.0),
+ Vectors.dense(0.0));
+ JavaDStream<Vector> training =
+ attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2);
+ List<Tuple2<Integer, Vector>> testBatch = Lists.newArrayList(
+ new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
+ new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
+ JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
+ attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2));
+ StreamingKMeans skmeans = new StreamingKMeans()
+ .setK(1)
+ .setDecayFactor(1.0)
+ .setInitialCenters(new Vector[]{Vectors.dense(1.0)}, new double[]{0.0});
+ skmeans.trainOn(training);
+ JavaPairDStream<Integer, Integer> prediction = skmeans.predictOnValues(test);
+ attachTestOutputStream(prediction.count());
+ runStreams(ssc, 2, 2);
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
new file mode 100644
index 000000000000..62f7f26b7c98
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat;
+
+import java.io.Serializable;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+public class JavaStatisticsSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaStatistics");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void testCorr() {
+ JavaRDD<Double> x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0));
+ JavaRDD<Double> y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3));
+
+ Double corr1 = Statistics.corr(x, y);
+ Double corr2 = Statistics.corr(x, y, "pearson");
+ // Check default method
+ assertEquals(corr1, corr2);
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 2b04a3034782..29394fefcbc4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,15 +17,17 @@
package org.apache.spark.ml
+import scala.collection.JavaConverters._
+
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
-import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar.mock
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
-class PipelineSuite extends FunSuite {
+class PipelineSuite extends SparkFunSuite {
abstract class MyModel extends Model[MyModel]
@@ -81,4 +83,19 @@ class PipelineSuite extends FunSuite {
pipeline.fit(dataset)
}
}
+
+ test("pipeline model constructors") {
+ val transform0 = mock[Transformer]
+ val model1 = mock[MyModel]
+
+ val stages = Array(transform0, model1)
+ val pipelineModel0 = new PipelineModel("pipeline0", stages)
+ assert(pipelineModel0.uid === "pipeline0")
+ assert(pipelineModel0.stages === stages)
+
+ val stagesAsList = stages.toList.asJava
+ val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList)
+ assert(pipelineModel1.uid === "pipeline1")
+ assert(pipelineModel1.stages === stages)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
index 17ddd335deb6..512cffb1acb6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.attribute
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class AttributeGroupSuite extends FunSuite {
+class AttributeGroupSuite extends SparkFunSuite {
test("attribute group") {
val attrs = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index ec9b717e41ce..72b575d02254 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.ml.attribute
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
-class AttributeSuite extends FunSuite {
+class AttributeSuite extends SparkFunSuite {
test("default numeric attribute") {
val attr: NumericAttribute = NumericAttribute.defaultAttr
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 3fdc66be8a31..ae40b0b8ff85 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
@@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
+class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import DecisionTreeClassifierSuite.compareAPIs
@@ -251,7 +250,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
*/
}
-private[ml] object DecisionTreeClassifierSuite extends FunSuite {
+private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
/**
* Train 2 decision trees on the given dataset, one using the old API and one using the new API.
@@ -266,7 +265,7 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite {
val oldTree = OldDecisionTree.train(data, oldStrategy)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyways.
val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index ea86867f1161..1302da3c373f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[GBTClassifier]].
*/
-class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext {
+class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import GBTClassifierSuite.compareAPIs
@@ -128,7 +127,7 @@ private object GBTClassifierSuite {
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
val newModel = gbt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyways.
val oldModelAsNew = GBTClassificationModel.fromOld(
oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 9f77d5f3efc5..a755cac3ea76 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
@transient var binaryDataset: DataFrame = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 770b56890fa4..1d04ccb50905 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -30,7 +29,7 @@ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
+class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
@transient var rdd: RDD[LabeledPoint] = _
@@ -94,6 +93,15 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
ova.fit(datasetWithLabelMetadata)
}
+
+ test("SPARK-8049: OneVsRest shouldn't output temp columns") {
+ val logReg = new LogisticRegression()
+ .setMaxIter(1)
+ val ovr = new OneVsRest()
+ .setClassifier(logReg)
+ val output = ovr.fit(dataset).transform(dataset)
+ assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index cdbbacab8e0e..eee9355a67be 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
@@ -32,7 +31,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestClassifier]].
*/
-class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext {
+class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import RandomForestClassifierSuite.compareAPIs
@@ -158,7 +157,7 @@ private object RandomForestClassifierSuite {
data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val newModel = rf.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyways.
val oldModelAsNew = RandomForestClassificationModel.fromOld(
oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index 3ea7aad5274f..36a1ac6b7996 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.ml.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
-class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
+class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Regression Evaluator: default params") {
/**
@@ -39,7 +38,7 @@ class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
val dataset = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
-
+
/**
* Using the following R code to load the data, train the model and evaluate metrics.
*
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 8f6c6b39dc93..7953bd041719 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
-class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
+class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var data: Array[Double] = _
@@ -48,7 +47,7 @@ class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
test("Binarize continuous features with setter") {
val threshold: Double = 0.2
- val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
+ val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
data.zip(thresholdBinarized)).toDF("feature", "expected")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 0391bd8427c2..507a8a7db24c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -19,15 +19,13 @@ package org.apache.spark.ml.feature
import scala.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
+class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Bucket continuous features, without -inf,inf") {
// Check a set of valid feature values.
@@ -110,7 +108,7 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
}
}
-private object BucketizerSuite extends FunSuite {
+private object BucketizerSuite extends SparkFunSuite {
/** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */
def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
require(feature >= splits.head)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 2e4beb0bfff6..7b2d70e64400 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -26,7 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
+class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
val hashingTF = new HashingTF
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index f85e85471617..d83772e8be75 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class IDFSuite extends FunSuite with MLlibTestSparkContext {
+class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
dataSet.map {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index 9d09f24709e2..9f03470b7f32 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
+class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var data: Array[Vector] = _
@transient var dataFrame: DataFrame = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 9018d0024d5f..2e5036a84456 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
-class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
+class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index aa230ca073d5..feca866cd711 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -17,15 +17,15 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
import org.scalatest.exceptions.TestFailedException
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext {
+class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Polynomial expansion with default parameter") {
val data = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 89c2fe45573a..5f557e16e515 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
+class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
@@ -61,4 +60,12 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
}
+
+ test("StringIndexerModel should keep silent if the input column does not exist.") {
+ val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ val df = sqlContext.range(0L, 10L)
+ assert(indexerModel.transform(df).eq(df))
+ }
}
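Note (illustrative, not part of the patch): every hunk in this group makes the same two-line change, dropping the direct org.scalatest.FunSuite import and extending org.apache.spark.SparkFunSuite instead, so project-wide test behavior lives in one base class. A minimal sketch of what such a shared base suite can look like follows; the package, class name, and logging behavior are assumptions for illustration only, not Spark's actual SparkFunSuite.

// Hypothetical sketch (assumption) of a shared ScalaTest base class in the
// spirit of the SparkFunSuite these hunks adopt: one place to hang
// project-wide test behavior, here printing each test's name around its body.
package org.example.testing

import org.scalatest.{FunSuite, Outcome}

abstract class ProjectFunSuite extends FunSuite {
  // Runs around every test defined in suites that extend this class.
  protected override def withFixture(test: NoArgTest): Outcome = {
    println(s"===== TEST BEGIN: ${test.name} =====")
    try {
      super.withFixture(test) // execute the actual test body
    } finally {
      println(s"===== TEST END:   ${test.name} =====")
    }
  }
}

// Usage: a suite extends the shared base instead of FunSuite directly, e.g.
// class MySuite extends ProjectFunSuite { test("example") { assert(1 + 1 === 2) } }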
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index eabda089d098..ac279cb3215c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -19,15 +19,14 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
-class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
+class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._
test("RegexTokenizer") {
@@ -60,7 +59,7 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
}
}
-object RegexTokenizerSuite extends FunSuite {
+object RegexTokenizerSuite extends SparkFunSuite {
def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
t.transform(dataset)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index 43534e89928b..489abb5af713 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
-class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
+class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index b11b029c6343..06affc7305cf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -19,16 +19,14 @@ package org.apache.spark.ml.feature
import scala.beans.{BeanInfo, BeanProperty}
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
+class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
import VectorIndexerSuite.FeatureData
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index df446d0c2201..94ebc3aebfa3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
-class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
+class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Word2Vec") {
val sqlContext = new SQLContext(sc)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 1505ad872536..778abcba22c1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -19,8 +19,7 @@ package org.apache.spark.ml.impl
import scala.collection.JavaConverters._
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.tree._
@@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, DataFrame}
-private[ml] object TreeTests extends FunSuite {
+private[ml] object TreeTests extends SparkFunSuite {
/**
* Convert the given data to a DataFrame, and set the features and label metadata.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 04f2af4727ea..96094d7a099a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.param
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class ParamsSuite extends FunSuite {
+class ParamsSuite extends SparkFunSuite {
test("param") {
val solver = new TestParams()
@@ -27,7 +27,7 @@ class ParamsSuite extends FunSuite {
import solver.{maxIter, inputCol}
assert(maxIter.name === "maxIter")
- assert(maxIter.doc === "max number of iterations (>= 0)")
+ assert(maxIter.doc === "maximum number of iterations (>= 0)")
assert(maxIter.parent === uid)
assert(maxIter.toString === s"${uid}__maxIter")
assert(!maxIter.isValid(-1))
@@ -36,7 +36,7 @@ class ParamsSuite extends FunSuite {
solver.setMaxIter(5)
assert(solver.explainParam(maxIter) ===
- "maxIter: max number of iterations (>= 0) (default: 10, current: 5)")
+ "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)")
assert(inputCol.toString === s"${uid}__inputCol")
@@ -120,7 +120,7 @@ class ParamsSuite extends FunSuite {
intercept[NoSuchElementException](solver.getInputCol)
assert(solver.explainParam(maxIter) ===
- "maxIter: max number of iterations (>= 0) (default: 10, current: 100)")
+ "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
assert(solver.explainParams() ===
Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))
@@ -202,7 +202,7 @@ class ParamsSuite extends FunSuite {
}
}
-object ParamsSuite extends FunSuite {
+object ParamsSuite extends SparkFunSuite {
/**
* Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
index ca18fa1ad3c1..eb5408d3fee7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.ml.param.shared
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.Params
-class SharedParamsSuite extends FunSuite {
+class SharedParamsSuite extends SparkFunSuite {
test("outputCol") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 9a35555e52b9..2e5cfe7027eb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -25,9 +25,8 @@ import scala.collection.mutable.ArrayBuffer
import scala.language.existentials
import com.github.fommil.netlib.BLAS.{getInstance => blas}
-import org.scalatest.FunSuite
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.recommendation.ALS._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -36,7 +35,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.Utils
-class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
+class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
private var tempDir: File = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 1196a772dfdd..33aa9d0d6234 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
@@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
+class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
import DecisionTreeRegressorSuite.compareAPIs
@@ -69,7 +68,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
// TODO: test("model save/load") SPARK-6725
}
-private[ml] object DecisionTreeRegressorSuite extends FunSuite {
+private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
/**
* Train 2 decision trees on the given dataset, one using the old API and one using the new API.
@@ -83,7 +82,7 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite {
val oldTree = OldDecisionTree.train(data, oldStrategy)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newTree = dt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyway.
val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 40e7e3273e96..98fb3d3f5f22 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[GBTRegressor]].
*/
-class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext {
+class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
import GBTRegressorSuite.compareAPIs
@@ -129,7 +128,7 @@ private object GBTRegressorSuite {
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = gbt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newModel since this is not checked anyway.
val oldModelAsNew = GBTRegressionModel.fromOld(
oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 50a78631fa6d..732e2c42be14 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 3efffbb763b7..b24ecaa57c89 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestRegressor]].
*/
-class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
+class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
import RandomForestRegressorSuite.compareAPIs
@@ -98,7 +97,7 @@ class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
*/
}
-private object RandomForestRegressorSuite extends FunSuite {
+private object RandomForestRegressorSuite extends SparkFunSuite {
/**
* Train 2 models on the given dataset, one using the old API and one using the new API.
@@ -114,7 +113,7 @@ private object RandomForestRegressorSuite extends FunSuite {
data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = rf.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newModel since this is not checked anyway.
val oldModelAsNew = RandomForestRegressionModel.fromOld(
oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 60d8bfe38fb1..9b3619f0046e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.tuning
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.StructType
-class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
+class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
@@ -56,6 +56,7 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
+ assert(cvModel.avgMetrics.length === lrParamMaps.length)
}
test("validateParams should check estimatorParamMaps") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
index 20aa100112bf..810b70049ec1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
@@ -19,11 +19,10 @@ package org.apache.spark.ml.tuning
import scala.collection.mutable
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.{ParamMap, TestParams}
-class ParamGridBuilderSuite extends FunSuite {
+class ParamGridBuilderSuite extends SparkFunSuite {
val solver = new TestParams()
import solver.{inputCol, maxIter}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index 3d362b5ee53e..59944416d96a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.api.python
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.recommendation.Rating
-class PythonMLLibAPISuite extends FunSuite {
+class PythonMLLibAPISuite extends SparkFunSuite {
SerDe.initialize()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 966811a5a326..e8f3d0c4db20 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -21,9 +21,9 @@ import scala.collection.JavaConversions._
import scala.util.Random
import scala.util.control.Breaks._
-import org.scalatest.FunSuite
import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -119,7 +119,7 @@ object LogisticRegressionSuite {
}
// Preventing the overflow when we compute the probability
val maxMargin = margins.max
- if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin
+ if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin
// Computing the probabilities for each class from the margins.
val norm = {
@@ -130,7 +130,7 @@ object LogisticRegressionSuite {
}
temp
}
- for (i <-0 until nClasses) probs(i) /= norm
+ for (i <- 0 until nClasses) probs(i) /= norm
// Compute the cumulative probability so we can generate a random number and assign a label.
for (i <- 1 until nClasses) probs(i) += probs(i - 1)
@@ -169,7 +169,7 @@ object LogisticRegressionSuite {
}
-class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
def validatePrediction(
predictions: Seq[Double],
input: Seq[LabeledPoint],
@@ -541,7 +541,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
}
-class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction using SGD optimizer") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index ea40b41bbbe5..f7fc8730606a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -21,9 +21,8 @@ import scala.util.Random
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -86,7 +85,7 @@ object NaiveBayesSuite {
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial)
}
-class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
+class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
import NaiveBayes.{Multinomial, Bernoulli}
@@ -286,7 +285,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
+class NaiveBayesClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 10
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 90f9cec6855b..b1d78cba9e3d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -21,9 +21,8 @@ import scala.collection.JavaConversions._
import scala.util.Random
import org.jblas.DoubleMatrix
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -62,7 +61,7 @@ object SVMSuite {
}
-class SVMSuite extends FunSuite with MLlibTestSparkContext {
+class SVMSuite extends SparkFunSuite with MLlibTestSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
@@ -229,7 +228,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
+class SVMClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index 5683b55e8500..fd653296c9d9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -19,15 +19,14 @@ package org.apache.spark.mllib.classification
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.TestSuiteBase
-class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
+class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 30000
@@ -159,4 +158,21 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
assert(error.head > 0.8 & error.last < 0.2)
}
+
+ // Test empty RDDs in a stream
+ test("handling empty RDDs in a stream") {
+ val model = new StreamingLogisticRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(-0.1))
+ .setStepSize(0.01)
+ .setNumIterations(10)
+ val numBatches = 10
+ val emptyInput = Seq.empty[Seq[LabeledPoint]]
+ val ssc = setupStreams(emptyInput,
+ (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ }
+ )
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index f356ffa3e3a2..b218d72f1268 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.clustering
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
+class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
test("single cluster") {
val data = sc.parallelize(Array(
Vectors.dense(6.0, 9.0),
@@ -47,7 +46,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
}
-
+
test("two clusters") {
val data = sc.parallelize(GaussianTestData.data)
@@ -63,7 +62,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
-
+
val gmm = new GaussianMixture()
.setK(2)
.setInitialModel(initialGmm)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 877e6dc69952..0dbbd7127444 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.mllib.clustering
import scala.util.Random
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class KMeansSuite extends FunSuite with MLlibTestSparkContext {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
@@ -281,7 +280,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
}
}
-object KMeansSuite extends FunSuite {
+object KMeansSuite extends SparkFunSuite {
def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
val singlePoint = isSparse match {
case true =>
@@ -305,7 +304,7 @@ object KMeansSuite extends FunSuite {
}
}
-class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {
+class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index d5b7d9633574..406affa25539 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM}
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class LDASuite extends FunSuite with MLlibTestSparkContext {
+class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
import LDASuite._
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 556842f3129a..19e65f1b53ab 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -20,15 +20,13 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable
import scala.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
+class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.mllib.clustering.PowerIterationClustering._
@@ -58,7 +56,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
predictions(a.cluster) += a.id
}
assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
-
+
val model2 = new PowerIterationClustering()
.setK(2)
.setInitializationMode("degree")
@@ -130,7 +128,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
}
}
-object PowerIterationClusteringSuite extends FunSuite {
+object PowerIterationClusteringSuite extends SparkFunSuite {
def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
val assignments = sc.parallelize(
(0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
index f90025d535e4..ac01622b8a08 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.clustering
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.random.XORShiftRandom
-class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
+class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
override def maxWaitTimeMillis: Int = 30000
@@ -133,6 +132,13 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
assert(math.abs(c1) ~== 0.8 absTol 0.6)
}
+ test("SPARK-7946 setDecayFactor") {
+ val kMeans = new StreamingKMeans()
+ assert(kMeans.decayFactor === 1.0)
+ kMeans.setDecayFactor(2.0)
+ assert(kMeans.decayFactor === 2.0)
+ }
+
def StreamingKMeansDataGenerator(
numPoints: Int,
numBatches: Int,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
index 79847633ff0d..87ccc7eda44e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext {
+class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext {
test("auc computation") {
val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
val auc = 4.0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index e0224f960cc4..99d52fabc530 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
index 7dc4f3cfbc4e..d55bc8c3ec09 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Multiclass evaluation metrics") {
/*
* Confusion matrix for 3-class classification with total 9 instances:
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
index 2537dd62c92f..f3b19aeb42f8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Multilabel evaluation metrics") {
/*
* Documents true labels (5x class0, 3x class1, 4x class2):
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index 609eed983ff4..c0924a213a84 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Ranking metrics: map, ndcg") {
val predictionAndLabels = sc.parallelize(
Seq(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index 3aa732474ec2..9de2bdb6d724 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("regression metrics") {
val predictionAndObservations = sc.parallelize(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index 747f5914598e..889727fb5582 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext {
+class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
/*
* Contingency tables
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
index f3a482abda87..ccbf8a91cdd3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class ElementwiseProductSuite extends FunSuite with MLlibTestSparkContext {
+class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext {
test("elementwise (hadamard) product should properly apply vector to dense data set") {
val denseData = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
index 0c4dfb7b97c7..cf279c02334e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
+class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("hashing tf on a single doc") {
val hashingTF = new HashingTF(1000)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
index 0a5cad7caf8e..21163633051e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class IDFSuite extends FunSuite with MLlibTestSparkContext {
+class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("idf") {
val n = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
index 5c4af2b99e68..34122d6ed2e9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
import breeze.linalg.{norm => brzNorm}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
+class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
val data = Array(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
index 758af588f1c6..e57f49191378 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class PCASuite extends FunSuite with MLlibTestSparkContext {
+class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
private val data = Array(
Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index 1eb991869de4..6ab2fa677012 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
import org.apache.spark.rdd.RDD
-class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
+class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
// When the input data is all constant, the variance is zero. The standardization against
// zero variance is not well-defined, but we decide to just set it into zero here.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 98a98a7599bc..b6818369208d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
+class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
// TODO: add more tests
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index bd5b9cc3afa1..66ae3543ecc4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -16,11 +16,10 @@
*/
package org.apache.spark.mllib.fpm
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
+class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
test("FP-Growth using String type") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
index 04017f67c311..a56d7b357921 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
@@ -19,11 +19,10 @@ package org.apache.spark.mllib.fpm
import scala.language.existentials
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class FPTreeSuite extends FunSuite with MLlibTestSparkContext {
+class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
test("add transaction") {
val tree = new FPTree[String]
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index 699f009f0f2e..d34888af2d73 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -17,18 +17,16 @@
package org.apache.spark.mllib.impl
-import org.scalatest.FunSuite
-
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext {
+class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
import PeriodicGraphCheckpointerSuite._
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 64ecd12ea7de..b0f3f71113c5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.linalg
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.linalg.BLAS._
-class BLASSuite extends FunSuite {
+class BLASSuite extends SparkFunSuite {
test("copy") {
val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0))
@@ -140,7 +139,7 @@ class BLASSuite extends FunSuite {
syr(alpha, x, dA)
assert(dA ~== expected absTol 1e-15)
-
+
val dB =
new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0))
@@ -149,7 +148,7 @@ class BLASSuite extends FunSuite {
syr(alpha, x, dB)
}
}
-
+
val dC =
new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8))
@@ -158,7 +157,7 @@ class BLASSuite extends FunSuite {
syr(alpha, x, dC)
}
}
-
+
val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5))
withClue("Size of vector must match the rank of matrix") {
@@ -256,13 +255,13 @@ class BLASSuite extends FunSuite {
val dA =
new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
-
+
val dA2 =
new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true)
val sA2 =
new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0),
true)
-
+
val dx = new DenseVector(Array(1.0, 2.0, 3.0))
val sx = dx.toSparse
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))
@@ -271,7 +270,7 @@ class BLASSuite extends FunSuite {
assert(sA.multiply(dx) ~== expected absTol 1e-15)
assert(dA.multiply(sx) ~== expected absTol 1e-15)
assert(sA.multiply(sx) ~== expected absTol 1e-15)
-
+
val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
val y2 = y1.copy
val y3 = y1.copy
@@ -288,7 +287,7 @@ class BLASSuite extends FunSuite {
val y14 = y1.copy
val y15 = y1.copy
val y16 = y1.copy
-
+
val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))
@@ -296,42 +295,42 @@ class BLASSuite extends FunSuite {
gemv(1.0, sA, dx, 2.0, y2)
gemv(1.0, dA, sx, 2.0, y3)
gemv(1.0, sA, sx, 2.0, y4)
-
+
gemv(1.0, dA2, dx, 2.0, y5)
gemv(1.0, sA2, dx, 2.0, y6)
gemv(1.0, dA2, sx, 2.0, y7)
gemv(1.0, sA2, sx, 2.0, y8)
-
+
gemv(2.0, dA, dx, 2.0, y9)
gemv(2.0, sA, dx, 2.0, y10)
gemv(2.0, dA, sx, 2.0, y11)
gemv(2.0, sA, sx, 2.0, y12)
-
+
gemv(2.0, dA2, dx, 2.0, y13)
gemv(2.0, sA2, dx, 2.0, y14)
gemv(2.0, dA2, sx, 2.0, y15)
gemv(2.0, sA2, sx, 2.0, y16)
-
+
assert(y1 ~== expected2 absTol 1e-15)
assert(y2 ~== expected2 absTol 1e-15)
assert(y3 ~== expected2 absTol 1e-15)
assert(y4 ~== expected2 absTol 1e-15)
-
+
assert(y5 ~== expected2 absTol 1e-15)
assert(y6 ~== expected2 absTol 1e-15)
assert(y7 ~== expected2 absTol 1e-15)
assert(y8 ~== expected2 absTol 1e-15)
-
+
assert(y9 ~== expected3 absTol 1e-15)
assert(y10 ~== expected3 absTol 1e-15)
assert(y11 ~== expected3 absTol 1e-15)
assert(y12 ~== expected3 absTol 1e-15)
-
+
assert(y13 ~== expected3 absTol 1e-15)
assert(y14 ~== expected3 absTol 1e-15)
assert(y15 ~== expected3 absTol 1e-15)
assert(y16 ~== expected3 absTol 1e-15)
-
+
withClue("columns of A don't match the rows of B") {
intercept[Exception] {
gemv(1.0, dA.transpose, dx, 2.0, y1)
@@ -346,12 +345,12 @@ class BLASSuite extends FunSuite {
gemv(1.0, sA.transpose, sx, 2.0, y1)
}
}
-
+
val dAT =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
val sAT =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
-
+
val dATT = dAT.transpose
val sATT = sAT.transpose
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
index 203103237397..dc04258e41d2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.mllib.linalg
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM}
-class BreezeMatrixConversionSuite extends FunSuite {
+import org.apache.spark.SparkFunSuite
+
+class BreezeMatrixConversionSuite extends SparkFunSuite {
test("dense matrix to breeze") {
val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
val breeze = mat.toBreeze.asInstanceOf[BDM[Double]]
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
index 8abdac72902c..3772c9235ad3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
@@ -17,14 +17,14 @@
package org.apache.spark.mllib.linalg
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
+import org.apache.spark.SparkFunSuite
+
/**
* Test Breeze vector conversions.
*/
-class BreezeVectorConversionSuite extends FunSuite {
+class BreezeVectorConversionSuite extends SparkFunSuite {
val arr = Array(0.1, 0.2, 0.3, 0.4)
val n = 20
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 86119ec38101..8dbb70f5d1c4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark.mllib.linalg
import java.util.Random
import org.mockito.Mockito.when
-import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar._
import scala.collection.mutable.{Map => MutableMap}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
-class MatricesSuite extends FunSuite {
+class MatricesSuite extends SparkFunSuite {
test("dense matrix construction") {
val m = 3
val n = 2
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 24755e9ff46f..c4ae0a16f7c0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -20,12 +20,11 @@ package org.apache.spark.mllib.linalg
import scala.util.Random
import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.util.TestingUtils._
-class VectorsSuite extends FunSuite {
+class VectorsSuite extends SparkFunSuite {
val arr = Array(0.1, 0.0, 0.3, 0.4)
val n = 4
@@ -215,13 +214,13 @@ class VectorsSuite extends FunSuite {
val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze)
- // SparseVector vs. SparseVector
- assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
+ // SparseVector vs. SparseVector
+ assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
// DenseVector vs. SparseVector
assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
// DenseVector vs. DenseVector
assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8)
- }
+ }
}
test("foreachActive") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
index a58336175899..93fe04c139b9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -20,14 +20,13 @@ package org.apache.spark.mllib.linalg.distributed
import java.{util => ju}
import breeze.linalg.{DenseMatrix => BDM}
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 5
val n = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
index 04b36a9ef999..f3728cd036a3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.linalg.distributed
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseMatrix => BDM}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.linalg.Vectors
-class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 5
val n = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index 2ab53cc13db7..4a7b99a976f0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.linalg.distributed
-import org.scalatest.FunSuite
-
import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
-class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 4
val n = 3
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index 27bb19f472e1..b6cb53d0c743 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.mllib.linalg.distributed
import scala.util.Random
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
-class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 4
val n = 3
@@ -240,7 +240,7 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext {
+class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
var mat: RowMatrix = _
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index e110506d579b..a5a59e9fad5a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -20,8 +20,9 @@ package org.apache.spark.mllib.optimization
import scala.collection.JavaConversions._
import scala.util.Random
-import org.scalatest.{FunSuite, Matchers}
+import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -61,7 +62,7 @@ object GradientDescentSuite {
}
}
-class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
test("Assert the loss is decreasing.") {
val nPoints = 10000
@@ -140,7 +141,7 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc
}
}
-class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext {
+class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index c8f2adcf155a..d07b9d5b8922 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization
import scala.util.Random
-import org.scalatest.{FunSuite, Matchers}
+import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
-class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
val nPoints = 10000
val A = 2.0
@@ -229,7 +230,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
}
}
-class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small") {
val m = 10
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
index bb723fc47118..d8f9b8c33963 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.mllib.optimization
import scala.util.Random
-import org.scalatest.FunSuite
-
import org.jblas.{DoubleMatrix, SimpleBlas}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
-class NNLSSuite extends FunSuite {
+class NNLSSuite extends SparkFunSuite {
/** Generate an NNLS problem whose optimal solution is the all-ones vector. */
def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = {
val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
index 0b646cf1ce6c..4c6e76e47419 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.mllib.pmml.export
import org.dmg.pmml.RegressionModel
import org.dmg.pmml.RegressionNormalizationMethodType
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.util.LinearDataGenerator
-class BinaryClassificationPMMLModelExportSuite extends FunSuite {
+class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite {
test("logistic regression PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
@@ -53,13 +53,13 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite {
// ensure logistic regression has normalization method set to LOGIT
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT)
}
-
+
test("linear SVM PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
-
+
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
-
+
// assert that the PMML format is as expected
assert(svmModelExport.isInstanceOf[PMMLModelExport])
val pmml = svmModelExport.getPmml
@@ -80,5 +80,5 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite {
// ensure linear SVM has normalization method set to NONE
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE)
}
-
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
index f9afbd888dfc..1d3230948178 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.mllib.pmml.export
import org.dmg.pmml.RegressionModel
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator
-class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
+class GeneralizedLinearPMMLModelExportSuite extends SparkFunSuite {
test("linear regression PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
index b985d0446d7b..b3f9750afa73 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.mllib.pmml.export
import org.dmg.pmml.ClusteringModel
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
-class KMeansPMMLModelExportSuite extends FunSuite {
+class KMeansPMMLModelExportSuite extends SparkFunSuite {
test("KMeansPMMLModelExport generate PMML format") {
val clusterCenters = Array(
@@ -45,5 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite {
val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
}
-
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
index f28a4ac8ad01..af4945096175 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.pmml.export
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel}
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator
-class PMMLModelExportFactorySuite extends FunSuite {
+class PMMLModelExportFactorySuite extends SparkFunSuite {
test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {
val clusterCenters = Array(
@@ -61,25 +60,25 @@ class PMMLModelExportFactorySuite extends FunSuite {
test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport "
+ "when passing a LogisticRegressionModel or SVMModel") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
-
+
val logisticRegressionModel =
new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
val logisticRegressionModelExport =
PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
-
+
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
}
-
+
test("PMMLModelExportFactory throw IllegalArgumentException "
+ "when passing a Multinomial Logistic Regression") {
/** 3 classes, 2 features */
val multiclassLogisticRegressionModel = new LogisticRegressionModel(
- weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
+ weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
numFeatures = 2, numClasses = 3)
-
+
intercept[IllegalArgumentException] {
PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
index b792d819fdab..a5ca1518f82f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
@@ -19,12 +19,11 @@ package org.apache.spark.mllib.random
import scala.math
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.util.StatCounter
// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
-class RandomDataGeneratorSuite extends FunSuite {
+class RandomDataGeneratorSuite extends SparkFunSuite {
def apiChecks(gen: RandomDataGenerator[Double]) {
// resetting seed should generate the same sequence of random numbers
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
index 63f2ea916d45..413db2000d6d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.random
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD}
@@ -34,7 +33,7 @@ import org.apache.spark.util.StatCounter
*
* TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
*/
-class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable {
+class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable {
def testGeneratedRDD(rdd: RDD[Double],
expectedSize: Long,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
index 57216e8eb4a5..10f5a2be48f7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.rdd
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
-class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
+class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("topByKey") {
val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5,
1), (3, 5)), 2)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 6d6c0aa5be81..bc6417261483 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.rdd
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.rdd.RDDFunctions._
-class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
+class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("sliding") {
val data = 0 until 6
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index b3798940ddc3..05b87728d6fd 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -21,9 +21,9 @@ import scala.collection.JavaConversions._
import scala.math.abs
import scala.util.Random
-import org.scalatest.FunSuite
import org.jblas.DoubleMatrix
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.storage.StorageLevel
@@ -84,7 +84,7 @@ object ALSSuite {
}
-class ALSSuite extends FunSuite with MLlibTestSparkContext {
+class ALSSuite extends SparkFunSuite with MLlibTestSparkContext {
test("rank-1 matrices") {
testALS(50, 100, 1, 15, 0.7, 0.3)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
index 2c92866f3893..2c8ed057a516 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.recommendation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
-class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
+class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext {
val rank = 2
var userFeatures: RDD[(Int, Array[Double])] = _
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
index 3b38bdf5ef5e..ea4f2865757c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
@@ -17,13 +17,14 @@
package org.apache.spark.mllib.regression
-import org.scalatest.{Matchers, FunSuite}
+import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
private def round(d: Double) = {
math.round(d * 100).toDouble / 100
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
index 110c44a7193f..d8364a06de4d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.mllib.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
-class LabeledPointSuite extends FunSuite {
+class LabeledPointSuite extends SparkFunSuite {
test("parse labeled points") {
val points = Seq(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 71dce5092299..08a152ffc7a2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression
import scala.util.Random
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
@@ -32,7 +31,7 @@ private object LassoSuite {
val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}
-class LassoSuite extends FunSuite with MLlibTestSparkContext {
+class LassoSuite extends SparkFunSuite with MLlibTestSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
@@ -143,7 +142,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 3781931c2f81..f88a1c33c9f7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression
import scala.util.Random
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
@@ -32,7 +31,7 @@ private object LinearRegressionSuite {
val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}
-class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
@@ -150,7 +149,7 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index d6c93cc0e49c..7a781fee634c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.mllib.regression
import scala.util.Random
import org.jblas.DoubleMatrix
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
@@ -33,7 +33,7 @@ private object RidgeRegressionSuite {
val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}
-class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = {
predictions.zip(input).map { case (prediction, expected) =>
@@ -101,7 +101,7 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 26604dbe6c1e..f5e2d31056cb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.mllib.regression
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.TestSuiteBase
-class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
+class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 20000
@@ -167,4 +166,22 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
assert((error.head - error.last) > 2)
}
+
+ // Test empty RDDs in a stream
+ test("handling empty RDDs in a stream") {
+ val model = new StreamingLinearRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(0.0, 0.0))
+ .setStepSize(0.2)
+ .setNumIterations(25)
+ val numBatches = 10
+ val nPoints = 100
+ val emptyInput = Seq.empty[Seq[LabeledPoint]]
+ val ssc = setupStreams(emptyInput,
+ (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ }
+ )
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
index a7e6fce31ff7..c292ced75e87 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
@@ -17,16 +17,15 @@
package org.apache.spark.mllib.stat
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
SpearmanCorrelation}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class CorrelationSuite extends FunSuite with MLlibTestSparkContext {
+class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext {
// test input data
val xData = Array(1.0, 0.0, -2.0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
index 15418e603596..b084a5fb4313 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
@@ -19,16 +19,14 @@ package org.apache.spark.mllib.stat
import java.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.test.ChiSqTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext {
+class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext {
test("chi squared pearson goodness of fit") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
index a309c942cf8f..5feccdf33681 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
@@ -18,11 +18,11 @@
package org.apache.spark.mllib.stat
import org.apache.commons.math3.distribution.NormalDistribution
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
+class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
test("kernel density single sample") {
val rdd = sc.parallelize(Array(5.0))
val evaluationPoints = Array(5.0, 6.0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 23b0eec865de..07efde4f5e6d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.stat
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.TestingUtils._
-class MultivariateOnlineSummarizerSuite extends FunSuite {
+class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
test("basic error handing") {
val summarizer = new MultivariateOnlineSummarizer
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
index fac2498e4dcb..aa60deb665ae 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
@@ -17,49 +17,48 @@
package org.apache.spark.mllib.stat.distribution
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{ Vectors, Matrices }
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext {
+class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext {
test("univariate") {
val x1 = Vectors.dense(0.0)
val x2 = Vectors.dense(1.5)
-
+
val mu = Vectors.dense(0.0)
val sigma1 = Matrices.dense(1, 1, Array(1.0))
val dist1 = new MultivariateGaussian(mu, sigma1)
assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
-
+
val sigma2 = Matrices.dense(1, 1, Array(4.0))
val dist2 = new MultivariateGaussian(mu, sigma2)
assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
}
-
+
test("multivariate") {
val x1 = Vectors.dense(0.0, 0.0)
val x2 = Vectors.dense(1.0, 1.0)
-
+
val mu = Vectors.dense(0.0, 0.0)
val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
val dist1 = new MultivariateGaussian(mu, sigma1)
assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
-
+
val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
val dist2 = new MultivariateGaussian(mu, sigma2)
assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
}
-
+
test("multivariate degenerate") {
val x1 = Vectors.dense(0.0, 0.0)
val x2 = Vectors.dense(1.0, 1.0)
-
+
val mu = Vectors.dense(0.0, 0.0)
val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
val dist = new MultivariateGaussian(mu, sigma)
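For reference, the absolute-tolerance constants asserted in the univariate and multivariate tests above are direct evaluations of the Gaussian density pdf(x) = exp(-(x - mu)^2 / (2 * sigma^2)) / sqrt(2 * pi * sigma^2). A small helper (the name gaussianPdf is illustrative, not part of the patch) reproduces the one-dimensional values:

    // Density of a univariate Gaussian with mean mu and variance sigma^2.
    def gaussianPdf(x: Double, mu: Double, variance: Double): Double =
      math.exp(-(x - mu) * (x - mu) / (2 * variance)) / math.sqrt(2 * math.Pi * variance)

    gaussianPdf(0.0, 0.0, 1.0)  // ~0.39894
    gaussianPdf(1.5, 0.0, 1.0)  // ~0.12952
    gaussianPdf(0.0, 0.0, 4.0)  // ~0.19947
    gaussianPdf(1.5, 0.0, 4.0)  // ~0.15057

The two-dimensional cases follow the same formula with the covariance matrix in place of sigma^2; for example the standard bivariate normal at the origin gives 1 / (2 * pi), about 0.15915.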
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index ce983eb27fa3..356d957f1590 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import scala.collection.mutable
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
-class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
+class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
/////////////////////////////////////////////////////////////////////////////
// Tests examining individual elements of training
@@ -859,7 +858,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
}
-object DecisionTreeSuite extends FunSuite {
+object DecisionTreeSuite extends SparkFunSuite {
def validateClassifier(
model: DecisionTreeModel,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 55b0bac7d49f..84dd3b342d4c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.mllib.tree
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
@@ -32,7 +31,7 @@ import org.apache.spark.util.Utils
/**
* Test suite for [[GradientBoostedTrees]].
*/
-class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
+class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Regression with continuous features: SquaredError") {
GradientBoostedTreesSuite.testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
index 92b498580af0..49aff21fe791 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.tree
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
/**
* Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
*/
-class ImpuritySuite extends FunSuite with MLlibTestSparkContext {
+class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext {
test("Gini impurity does not support negative labels") {
val gini = new GiniAggregator(2)
intercept[IllegalArgumentException] {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 4ed66953cb62..e6df5d974bf3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.tree
import scala.collection.mutable
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -35,7 +34,7 @@ import org.apache.spark.util.Utils
/**
* Test suite for [[RandomForest]].
*/
-class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
+class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
index b184e936672c..9d756da41032 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.tree.impl
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.EnsembleTestHelper
import org.apache.spark.mllib.util.MLlibTestSparkContext
/**
* Test suite for [[BaggedPoint]].
*/
-class BaggedPointSuite extends FunSuite with MLlibTestSparkContext {
+class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
test("BaggedPoint RDD: without subsampling") {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index cdece2c174be..70219e9ad9d3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -21,19 +21,19 @@ import java.io.File
import scala.io.Source
-import org.scalatest.FunSuite
-
import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files
+import org.apache.spark.SparkException
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
+class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("epsilon computation") {
assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
@@ -63,7 +63,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
val fastSquaredDist3 =
fastSquaredDistance(v2, norm2, v3, norm3, precision)
assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m")
- if (m > 10) {
+ if (m > 10) {
val v4 = Vectors.sparse(n, indices.slice(0, m - 10),
indices.map(i => a(i) + 0.5).slice(0, m - 10))
val norm4 = Vectors.norm(v4, 2.0)
@@ -109,6 +109,40 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
Utils.deleteRecursively(tempDir)
}
+ test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") {
+ val lines =
+ """
+ |0
+ |0 0:4.0 4:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ intercept[SparkException] {
+ loadLibSVMFile(sc, path).collect()
+ }
+ Utils.deleteRecursively(tempDir)
+ }
+
+ test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") {
+ val lines =
+ """
+ |0
+ |0 3:4.0 2:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ intercept[SparkException] {
+ loadLibSVMFile(sc, path).collect()
+ }
+ Utils.deleteRecursively(tempDir)
+ }
+
test("saveAsLibSVMFile") {
val examples = sc.parallelize(Seq(
LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),
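The two new loadLibSVMFile tests above pin down the LIBSVM input contract: feature indices must be one-based and strictly ascending, and a violation raised inside a task surfaces on the driver wrapped in a SparkException, which is what the tests intercept. A minimal sketch of that index check -- the helper name below is hypothetical, not Spark's actual parser -- looks like this:

    // Hypothetical illustration of the rule the new tests rely on:
    // LIBSVM feature indices are one-based and strictly increasing.
    def requireValidLibSVMIndices(indices: Array[Int]): Unit = {
      var previous = 0 // valid indices start at 1, so 0 is the initial floor
      indices.foreach { index =>
        require(index > previous,
          s"indices should be one-based and in ascending order; found $index after $previous")
        previous = index
      }
    }

    requireValidLibSVMIndices(Array(1, 5, 7))     // ok
    // requireValidLibSVMIndices(Array(0, 4, 6)) -> IllegalArgumentException (zero-based)
    // requireValidLibSVMIndices(Array(3, 2, 6)) -> IllegalArgumentException (not ascending)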
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
index f68fb95eac4e..8dcb9ba9be10 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
@@ -17,11 +17,9 @@
package org.apache.spark.mllib.util
-import org.scalatest.FunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.SparkException
-
-class NumericParserSuite extends FunSuite {
+class NumericParserSuite extends SparkFunSuite {
test("parser") {
val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)"
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
index 59e6c778806f..8f475f30249d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
@@ -17,12 +17,12 @@
package org.apache.spark.mllib.util
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.scalatest.FunSuite
import org.apache.spark.mllib.util.TestingUtils._
import org.scalatest.exceptions.TestFailedException
-class TestingUtilsSuite extends FunSuite {
+class TestingUtilsSuite extends SparkFunSuite {
test("Comparing doubles using relative error.") {
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 0c3147761cfc..a85e0a66f4a3 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -22,7 +22,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
index 7dc7c65825e3..4b5bfcb6f04b 100644
--- a/network/shuffle/pom.xml
+++ b/network/shuffle/pom.xml
@@ -22,7 +22,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index 1e2e9c80af6c..a99f7c4392d3 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -22,7 +22,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/pom.xml b/pom.xml
index 711edf9efad2..6d4f717d4931 100644
--- a/pom.xml
+++ b/pom.xml
@@ -26,7 +26,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-parent_2.10</artifactId>
-  <version>1.4.0-SNAPSHOT</version>
+  <version>1.5.0-SNAPSHOT</version>
   <packaging>pom</packaging>
   <name>Spark Project Parent POM</name>
   <url>http://spark.apache.org/</url>
@@ -114,11 +114,10 @@
     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
     <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
-    <akka.group>org.spark-project.akka</akka.group>
-    <akka.version>2.3.4-spark</akka.version>
-    <java.version>1.6</java.version>
+    <akka.group>com.typesafe.akka</akka.group>
+    <akka.version>2.3.11</akka.version>
+    <java.version>1.7</java.version>
     <sbt.project.name>spark</sbt.project.name>
-    <scala.macros.version>2.0.1</scala.macros.version>
     <mesos.version>0.21.1</mesos.version>
     <mesos.classifier>shaded-protobuf</mesos.classifier>
     <slf4j.version>1.7.10</slf4j.version>
@@ -137,7 +136,7 @@
0.13.110.10.1.1
- 1.6.0rc3
+ 1.7.01.2.48.1.14.v201310313.0.0.v201112011016
@@ -180,7 +179,7 @@
compile${session.executionRootDirectory}
@@ -269,6 +268,18 @@
         <enabled>false</enabled>
+    <repository>
+      <id>spark-1.4-staging</id>
+      <name>Spark 1.4 RC4 Staging Repository</name>
+      <url>https://repository.apache.org/content/repositories/orgapachespark-1112</url>
+      <releases>
+        <enabled>true</enabled>
+      </releases>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+    </repository>
@@ -576,7 +587,7 @@
       <dependency>
         <groupId>io.netty</groupId>
         <artifactId>netty-all</artifactId>
-        <version>4.0.23.Final</version>
+        <version>4.0.28.Final</version>
       </dependency>
       <dependency>
         <groupId>org.apache.derby</groupId>
@@ -1069,13 +1080,13 @@
       <dependency>
-        <groupId>com.twitter</groupId>
+        <groupId>org.apache.parquet</groupId>
         <artifactId>parquet-column</artifactId>
         <version>${parquet.version}</version>
         <scope>${parquet.deps.scope}</scope>
       </dependency>
       <dependency>
-        <groupId>com.twitter</groupId>
+        <groupId>org.apache.parquet</groupId>
         <artifactId>parquet-hadoop</artifactId>
         <version>${parquet.version}</version>
         <scope>${parquet.deps.scope}</scope>
       </dependency>
@@ -1205,15 +1216,6 @@
               <javacArg>-target</javacArg>
               <javacArg>${java.version}</javacArg>
             </javacArgs>
-            <compilerPlugins>
-              <compilerPlugin>
-                <groupId>org.scalamacros</groupId>
-                <artifactId>paradise_${scala.version}</artifactId>
-                <version>${scala.macros.version}</version>
-              </compilerPlugin>
-            </compilerPlugins>
@@ -1242,7 +1244,7 @@
             <include>**/*Suite.java</include>
           </includes>
           <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
-          <argLine>-Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m</argLine>
+          <argLine>-Xmx3g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m</argLine>
+ false
@@ -1542,6 +1550,26 @@
+
+
+ org.apache.maven.plugins
+ maven-antrun-plugin
+
+
+ create-tmp-dir
+ generate-test-resources
+
+ run
+
+
+
+
+
+
+
+
+
+
org.apache.maven.plugins
@@ -1664,6 +1692,8 @@
0.98.7-hadoop1hadoop11.8.8
+ org.spark-project.akka
+ 2.3.4-spark
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index dde92949fa17..5812b72f0aa7 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -91,7 +91,8 @@ object MimaBuild {
def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
val organization = "org.apache.spark"
- val previousSparkVersion = "1.3.0"
+ // TODO: Change this once Spark 1.4.0 is released
+ val previousSparkVersion = "1.4.0-rc4"
val fullId = "spark-" + projectRef.project + "_2.10"
mimaDefaultSettings ++
Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 11b439e7875f..8a93ca299951 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -34,10 +34,31 @@ import com.typesafe.tools.mima.core.ProblemFilters._
object MimaExcludes {
def excludes(version: String) =
version match {
+ case v if v.startsWith("1.5") =>
+ Seq(
+ MimaBuild.excludeSparkPackage("deploy"),
+ // These are needed if checking against the sbt build, since they are part of
+ // the maven-generated artifacts in 1.3.
+ excludePackage("org.spark-project.jetty"),
+ MimaBuild.excludeSparkPackage("unused"),
+ // JavaRDDLike is not meant to be extended by user programs
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.partitioner"),
+ // Mima false positive (was a private[spark] class)
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.util.collection.PairIterator"),
+ // Removing a testing method from a private class
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
+ // SQL execution is considered private.
+ excludePackage("org.apache.spark.sql.execution")
+ )
case v if v.startsWith("1.4") =>
Seq(
MimaBuild.excludeSparkPackage("deploy"),
MimaBuild.excludeSparkPackage("ml"),
+ // SPARK-7910 Adding a method to get the partitioner to JavaRDD,
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"),
// SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"),
// These are needed if checking against the sbt build, since they are part of
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index b9515a12bc57..e01720296fed 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -23,11 +23,12 @@ import scala.collection.JavaConversions._
import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
-import sbtunidoc.Plugin.genjavadocSettings
import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys}
import net.virtualvoid.sbt.graph.Plugin.graphSettings
+import spray.revolver.RevolverPlugin._
+
object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
@@ -52,6 +53,8 @@ object BuildCommons {
// Root project.
val spark = ProjectRef(buildLocation, "spark")
val sparkHome = buildLocation
+
+ val testTempDir = s"$sparkHome/target/tmp"
}
object SparkBuild extends PomBuild {
@@ -118,7 +121,12 @@ object SparkBuild extends PomBuild {
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
- lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq (
+ lazy val sparkGenjavadocSettings: Seq[sbt.Def.Setting[_]] = Seq(
+ libraryDependencies += compilerPlugin(
+ "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full),
+ scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java")))
+
+ lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq (
javaHome := sys.env.get("JAVA_HOME")
.orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() })
.map(file),
@@ -126,7 +134,7 @@ object SparkBuild extends PomBuild {
retrieveManaged := true,
retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
publishMavenStyle := true,
- unidocGenjavadocVersion := "0.8",
+ unidocGenjavadocVersion := "0.9-spark0",
resolvers += Resolver.mavenLocal,
otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))),
@@ -140,7 +148,9 @@ object SparkBuild extends PomBuild {
javacOptions in (Compile, doc) ++= {
val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3)
if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty
- }
+ },
+
+ javacOptions in Compile ++= Seq("-encoding", "UTF-8")
)
def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = {
@@ -151,7 +161,7 @@ object SparkBuild extends PomBuild {
// Note ordering of these settings matter.
/* Enable shared settings on all projects */
(allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools))
- .foreach(enable(sharedSettings ++ ExludedDependencies.settings))
+ .foreach(enable(sharedSettings ++ ExludedDependencies.settings ++ Revolver.settings))
/* Enable tests settings for all projects except examples, assembly and tools */
(allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
@@ -174,9 +184,6 @@ object SparkBuild extends PomBuild {
/* Enable unidoc only for the root spark project */
enable(Unidoc.settings)(spark)
- /* Catalyst macro settings */
- enable(Catalyst.settings)(catalyst)
-
/* Spark SQL Core console settings */
enable(SQL.settings)(sql)
@@ -271,14 +278,6 @@ object OldDeps {
)
}
-object Catalyst {
- lazy val settings = Seq(
- addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full),
- // Quasiquotes break compiling scala doc...
- // TODO: Investigate fixing this.
- sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen")))
-}
-
object SQL {
lazy val settings = Seq(
initialCommands in console :=
@@ -503,6 +502,7 @@ object TestSettings {
"SPARK_DIST_CLASSPATH" ->
(fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"),
"JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))),
+ javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir",
javaOptions in Test += "-Dspark.test.home=" + sparkHome,
javaOptions in Test += "-Dspark.testing=1",
javaOptions in Test += "-Dspark.port.maxRetries=100",
@@ -511,10 +511,11 @@ object TestSettings {
javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
+ javaOptions in Test += "-Dderby.system.durability=test",
javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
.map { case (k,v) => s"-D$k=$v" }.toSeq,
javaOptions in Test += "-ea",
- javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g"
+ javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g"
.split(" ").toSeq,
javaOptions += "-Xmx3g",
// Show full stack trace and duration in test cases.
@@ -524,6 +525,13 @@ object TestSettings {
libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test",
// Only allow one test at a time, even across projects, since they run in the same JVM
parallelExecution in Test := false,
+ // Make sure the test temp directory exists.
+ resourceGenerators in Test <+= resourceManaged in Test map { outDir: File =>
+ if (!new File(testTempDir).isDirectory()) {
+ require(new File(testTempDir).mkdirs())
+ }
+ Seq[File]()
+ },
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
// Remove certain packages from Scaladoc
scalacOptions in (Compile, doc) := Seq(
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 7096b0d3ee7d..51820460ca1a 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -25,10 +25,12 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6")
addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1")
-addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1")
+addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3")
addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2")
+addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2")
+
libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3"
libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3"
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 0d21a132048a..adca90ddaf39 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -261,3 +261,7 @@ def _start_update_server():
thread.daemon = True
thread.start()
return server
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index aeb7ad4f2f83..90b2fffbb9c7 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -324,10 +324,18 @@ def stop(self):
with SparkContext._lock:
SparkContext._active_spark_context = None
- def range(self, start, end, step=1, numSlices=None):
+ def emptyRDD(self):
+ """
+ Create an RDD that has no partitions or elements.
+ """
+ return RDD(self._jsc.emptyRDD(), self, NoOpSerializer())
+
+ def range(self, start, end=None, step=1, numSlices=None):
"""
Create a new RDD of int containing elements from `start` to `end`
- (exclusive), increased by `step` every element.
+ (exclusive), increased by `step` every element. Can be called the same
+ way as python's built-in range() function. If called with a single argument,
+ the argument is interpreted as `end`, and `start` is set to 0.
:param start: the start value
:param end: the end value (exclusive)
@@ -335,9 +343,17 @@ def range(self, start, end, step=1, numSlices=None):
:param numSlices: the number of partitions of the new RDD
:return: An RDD of int
+ >>> sc.range(5).collect()
+ [0, 1, 2, 3, 4]
+ >>> sc.range(2, 4).collect()
+ [2, 3]
>>> sc.range(1, 7, 2).collect()
[1, 3, 5]
"""
+ if end is None:
+ end = start
+ start = 0
+
return self.parallelize(xrange(start, end, step), numSlices)
def parallelize(self, c, numSlices=None):
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 497841b6c8ce..0bf988fd72f1 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -91,20 +91,19 @@ class CrossValidator(Estimator):
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.mllib.linalg import Vectors
>>> dataset = sqlContext.createDataFrame(
- ... [(Vectors.dense([0.0, 1.0]), 0.0),
- ... (Vectors.dense([1.0, 2.0]), 1.0),
- ... (Vectors.dense([0.55, 3.0]), 0.0),
- ... (Vectors.dense([0.45, 4.0]), 1.0),
- ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
+ ... [(Vectors.dense([0.0]), 0.0),
+ ... (Vectors.dense([0.4]), 1.0),
+ ... (Vectors.dense([0.5]), 0.0),
+ ... (Vectors.dense([0.6]), 1.0),
+ ... (Vectors.dense([1.0]), 1.0)] * 10,
... ["features", "label"])
>>> lr = LogisticRegression()
- >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
+ >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- >>> # SPARK-7432: The following test is flaky.
- >>> # cvModel = cv.fit(dataset)
- >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
- >>> # cvModel.transform(dataset).collect() == expected.collect()
+ >>> cvModel = cv.fit(dataset)
+ >>> evaluator.evaluate(cvModel.transform(dataset))
+ 0.8333...
"""
# a placeholder to make it appear in the generated doc
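With the flaky-test workaround removed, the doctest again exercises the full tuning loop. A hedged sketch of that flow, assuming an active SQLContext bound to `sqlContext`; the dataset and grid values are illustrative only:

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
    from pyspark.mllib.linalg import Vectors

    dataset = sqlContext.createDataFrame(
        [(Vectors.dense([0.0]), 0.0), (Vectors.dense([0.6]), 1.0)] * 20,
        ["features", "label"])
    lr = LogisticRegression()
    # Each ParamMap in the grid is scored with k-fold cross validation.
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    evaluator = BinaryClassificationEvaluator()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=evaluator, numFolds=3)
    cvModel = cv.fit(dataset)        # keeps the winning model as cvModel.bestModel
    print(evaluator.evaluate(cvModel.transform(dataset)))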
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
index 07507b2ad0d0..acba3a717d21 100644
--- a/python/pyspark/mllib/__init__.py
+++ b/python/pyspark/mllib/__init__.py
@@ -23,16 +23,10 @@
# MLlib currently needs NumPy 1.4+, so complain if lower
import numpy
-if numpy.version.version < '1.4':
+
+ver = [int(x) for x in numpy.version.version.split('.')[:2]]
+if ver < [1, 4]:
raise Exception("MLlib requires NumPy 1.4+")
__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random',
'recommendation', 'regression', 'stat', 'tree', 'util']
-
-import sys
-from . import rand as random
-modname = __name__ + '.random'
-random.__name__ = modname
-random.RandomRDDs.__module__ = modname
-sys.modules[modname] = random
-del modname, sys
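The old check compared version strings lexicographically, which misclassifies NumPy 1.10+ as older than 1.4; the new check compares integer components. A standalone illustration:

    # String comparison is lexicographic, so the old check would reject NumPy 1.10.x.
    print('1.10.4' < '1.4')                                    # True  -- wrong
    print([int(x) for x in '1.10.4'.split('.')[:2]] < [1, 4])  # False -- correct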
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index a70c664a71fd..42e41397bf4b 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -33,8 +33,8 @@
class LinearClassificationModel(LinearModel):
"""
- A private abstract class representing a multiclass classification model.
- The categories are represented by int values: 0, 1, 2, etc.
+ A private abstract class representing a multiclass classification
+ model. The categories are represented by int values: 0, 1, 2, etc.
"""
def __init__(self, weights, intercept):
super(LinearClassificationModel, self).__init__(weights, intercept)
@@ -44,10 +44,11 @@ def setThreshold(self, value):
"""
.. note:: Experimental
- Sets the threshold that separates positive predictions from negative
- predictions. An example with prediction score greater than or equal
- to this threshold is identified as an positive, and negative otherwise.
- It is used for binary classification only.
+ Sets the threshold that separates positive predictions from
+ negative predictions. An example with prediction score greater
+ than or equal to this threshold is identified as a positive,
+ and negative otherwise. It is used for binary classification
+ only.
"""
self._threshold = value
@@ -56,8 +57,9 @@ def threshold(self):
"""
.. note:: Experimental
- Returns the threshold (if any) used for converting raw prediction scores
- into 0/1 predictions. It is used for binary classification only.
+ Returns the threshold (if any) used for converting raw
+ prediction scores into 0/1 predictions. It is used for
+ binary classification only.
"""
return self._threshold
@@ -65,22 +67,35 @@ def clearThreshold(self):
"""
.. note:: Experimental
- Clears the threshold so that `predict` will output raw prediction scores.
- It is used for binary classification only.
+ Clears the threshold so that `predict` will output raw
+ prediction scores. It is used for binary classification only.
"""
self._threshold = None
def predict(self, test):
"""
- Predict values for a single data point or an RDD of points using
- the model trained.
+ Predict values for a single data point or an RDD of points
+ using the model trained.
"""
raise NotImplementedError
class LogisticRegressionModel(LinearClassificationModel):
- """A linear binary classification model derived from logistic regression.
+ """
+ Classification model trained using Multinomial/Binary Logistic
+ Regression.
+
+ :param weights: Weights computed for every feature.
+ :param intercept: Intercept computed for this model. (Only used
+ in Binary Logistic Regression. In Multinomial Logistic
+ Regression, the intercepts will not be a single value,
+ so the intercepts will be part of the weights.)
+ :param numFeatures: the dimension of the features.
+ :param numClasses: the number of possible outcomes for k classes
+ classification problem in Multinomial Logistic Regression.
+ By default, it is binary logistic regression so numClasses
+ will be set to 2.
>>> data = [
... LabeledPoint(0.0, [0.0, 1.0]),
@@ -161,8 +176,8 @@ def numClasses(self):
def predict(self, x):
"""
- Predict values for a single data point or an RDD of points using
- the model trained.
+ Predict values for a single data point or an RDD of points
+ using the model trained.
"""
if isinstance(x, RDD):
return x.map(lambda v: self.predict(v))
@@ -225,16 +240,19 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
"""
Train a logistic regression model on the given data.
- :param data: The training data, an RDD of LabeledPoint.
- :param iterations: The number of iterations (default: 100).
+ :param data: The training data, an RDD of
+ LabeledPoint.
+ :param iterations: The number of iterations
+ (default: 100).
:param step: The step parameter used in SGD
(default: 1.0).
- :param miniBatchFraction: Fraction of data to be used for each SGD
- iteration.
+ :param miniBatchFraction: Fraction of data to be used for each
+ SGD iteration (default: 1.0).
:param initialWeights: The initial weights (default: None).
- :param regParam: The regularizer parameter (default: 0.01).
- :param regType: The type of regularizer used for training
- our model.
+ :param regParam: The regularizer parameter
+ (default: 0.01).
+ :param regType: The type of regularizer used for
+ training our model.
:Allowed values:
- "l1" for using L1 regularization
@@ -243,13 +261,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
(default: "l2")
- :param intercept: Boolean parameter which indicates the use
- or not of the augmented representation for
- training data (i.e. whether bias features
- are activated or not).
- :param validateData: Boolean parameter which indicates if the
- algorithm should validate data before training.
- (default: True)
+ :param intercept: Boolean parameter which indicates the
+ use or not of the augmented representation
+ for training data (i.e. whether bias
+ features are activated or not,
+ default: False).
+ :param validateData: Boolean parameter which indicates if
+ the algorithm should validate data
+ before training. (default: True)
"""
def train(rdd, i):
return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations),
@@ -267,12 +286,15 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType
"""
Train a logistic regression model on the given data.
- :param data: The training data, an RDD of LabeledPoint.
- :param iterations: The number of iterations (default: 100).
+ :param data: The training data, an RDD of
+ LabeledPoint.
+ :param iterations: The number of iterations
+ (default: 100).
:param initialWeights: The initial weights (default: None).
- :param regParam: The regularizer parameter (default: 0.01).
- :param regType: The type of regularizer used for training
- our model.
+ :param regParam: The regularizer parameter
+ (default: 0.01).
+ :param regType: The type of regularizer used for
+ training our model.
:Allowed values:
- "l1" for using L1 regularization
@@ -281,19 +303,21 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType
(default: "l2")
- :param intercept: Boolean parameter which indicates the use
- or not of the augmented representation for
- training data (i.e. whether bias features
- are activated or not).
- :param corrections: The number of corrections used in the LBFGS
- update (default: 10).
- :param tolerance: The convergence tolerance of iterations for
- L-BFGS (default: 1e-4).
+ :param intercept: Boolean parameter which indicates the
+ use or not of the augmented representation
+ for training data (i.e. whether bias
+ features are activated or not,
+ default: False).
+ :param corrections: The number of corrections used in the
+ LBFGS update (default: 10).
+ :param tolerance: The convergence tolerance of iterations
+ for L-BFGS (default: 1e-4).
:param validateData: Boolean parameter which indicates if the
- algorithm should validate data before training.
- (default: True)
- :param numClasses: The number of classes (i.e., outcomes) a label can take
- in Multinomial Logistic Regression (default: 2).
+ algorithm should validate data before
+ training. (default: True)
+ :param numClasses: The number of classes (i.e., outcomes) a
+ label can take in Multinomial Logistic
+ Regression (default: 2).
>>> data = [
... LabeledPoint(0.0, [0.0, 1.0]),
@@ -323,7 +347,11 @@ def train(rdd, i):
class SVMModel(LinearClassificationModel):
- """A support vector machine.
+ """
+ Model for Support Vector Machines (SVMs).
+
+ :param weights: Weights computed for every feature.
+ :param intercept: Intercept computed for this model.
>>> data = [
... LabeledPoint(0.0, [0.0]),
@@ -370,8 +398,8 @@ def __init__(self, weights, intercept):
def predict(self, x):
"""
- Predict values for a single data point or an RDD of points using
- the model trained.
+ Predict values for a single data point or an RDD of points
+ using the model trained.
"""
if isinstance(x, RDD):
return x.map(lambda v: self.predict(v))
@@ -409,16 +437,19 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
"""
Train a support vector machine on the given data.
- :param data: The training data, an RDD of LabeledPoint.
- :param iterations: The number of iterations (default: 100).
+ :param data: The training data, an RDD of
+ LabeledPoint.
+ :param iterations: The number of iterations
+ (default: 100).
:param step: The step parameter used in SGD
(default: 1.0).
- :param regParam: The regularizer parameter (default: 0.01).
- :param miniBatchFraction: Fraction of data to be used for each SGD
- iteration.
+ :param regParam: The regularizer parameter
+ (default: 0.01).
+ :param miniBatchFraction: Fraction of data to be used for each
+ SGD iteration (default: 1.0).
:param initialWeights: The initial weights (default: None).
- :param regType: The type of regularizer used for training
- our model.
+ :param regType: The type of regularizer used for
+ training our model.
:Allowed values:
- "l1" for using L1 regularization
@@ -427,13 +458,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
(default: "l2")
- :param intercept: Boolean parameter which indicates the use
- or not of the augmented representation for
- training data (i.e. whether bias features
- are activated or not).
- :param validateData: Boolean parameter which indicates if the
- algorithm should validate data before training.
- (default: True)
+ :param intercept: Boolean parameter which indicates the
+ use or not of the augmented representation
+ for training data (i.e. whether bias
+ features are activated or not,
+ default: False).
+ :param validateData: Boolean parameter which indicates if
+ the algorithm should validate data
+ before training. (default: True)
"""
def train(rdd, i):
return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step),
@@ -449,9 +481,11 @@ class NaiveBayesModel(Saveable, Loader):
"""
Model for Naive Bayes classifiers.
- Contains two parameters:
- - pi: vector of logs of class priors (dimension C)
- - theta: matrix of logs of class conditional probabilities (CxD)
+ :param labels: list of labels.
+ :param pi: log of class priors, whose dimension is C,
+ number of labels.
+ :param theta: log of class conditional probabilities, whose
+ dimension is C-by-D, where D is number of features.
>>> data = [
... LabeledPoint(0.0, [0.0, 0.0]),
@@ -493,7 +527,10 @@ def __init__(self, labels, pi, theta):
self.theta = theta
def predict(self, x):
- """Return the most likely class for a data vector or an RDD of vectors"""
+ """
+ Return the most likely class for a data vector
+ or an RDD of vectors
+ """
if isinstance(x, RDD):
return x.map(lambda v: self.predict(v))
x = _convert_to_vector(x)
@@ -523,16 +560,18 @@ class NaiveBayes(object):
@classmethod
def train(cls, data, lambda_=1.0):
"""
- Train a Naive Bayes model given an RDD of (label, features) vectors.
+ Train a Naive Bayes model given an RDD of (label, features)
+ vectors.
- This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which can
- handle all kinds of discrete data. For example, by converting
- documents into TF-IDF vectors, it can be used for document
- classification. By making every vector a 0-1 vector, it can also be
- used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}).
+ This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which
+ can handle all kinds of discrete data. For example, by
+ converting documents into TF-IDF vectors, it can be used for
+ document classification. By making every vector a 0-1 vector,
+ it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}).
+ The input feature values must be nonnegative.
:param data: RDD of LabeledPoint.
- :param lambda_: The smoothing parameter
+ :param lambda_: The smoothing parameter (default: 1.0).
"""
first = data.first()
if not isinstance(first, LabeledPoint):
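A hedged usage sketch for the multinomial Naive Bayes behaviour described in the docstring above (nonnegative count-style features); it assumes an active SparkContext `sc` and uses made-up toy data:

    from pyspark.mllib.classification import NaiveBayes
    from pyspark.mllib.linalg import Vectors
    from pyspark.mllib.regression import LabeledPoint

    data = sc.parallelize([
        LabeledPoint(0.0, Vectors.dense([2.0, 0.0, 1.0])),  # class 0 favours term 0
        LabeledPoint(0.0, Vectors.dense([3.0, 0.0, 0.0])),
        LabeledPoint(1.0, Vectors.dense([0.0, 1.0, 3.0])),  # class 1 favours term 2
        LabeledPoint(1.0, Vectors.dense([0.0, 2.0, 4.0])),
    ])
    model = NaiveBayes.train(data, lambda_=1.0)   # lambda_ is the smoothing parameter
    print(model.predict(Vectors.dense([1.0, 0.0, 0.0])))   # expected 0.0
    print(model.predict(Vectors.dense([0.0, 0.0, 5.0])))   # expected 1.0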
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index aab5e5f4b77b..c5cf3a4e7ff2 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -27,6 +27,8 @@ class BinaryClassificationMetrics(JavaModelWrapper):
"""
Evaluator for binary classification.
+ :param scoreAndLabels: an RDD of (score, label) pairs
+
>>> scoreAndLabels = sc.parallelize([
... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
>>> metrics = BinaryClassificationMetrics(scoreAndLabels)
@@ -38,9 +40,6 @@ class BinaryClassificationMetrics(JavaModelWrapper):
"""
def __init__(self, scoreAndLabels):
- """
- :param scoreAndLabels: an RDD of (score, label) pairs
- """
sc = scoreAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
@@ -76,6 +75,9 @@ class RegressionMetrics(JavaModelWrapper):
"""
Evaluator for regression.
+ :param predictionAndObservations: an RDD of (prediction,
+ observation) pairs.
+
>>> predictionAndObservations = sc.parallelize([
... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
>>> metrics = RegressionMetrics(predictionAndObservations)
@@ -92,9 +94,6 @@ class RegressionMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndObservations):
- """
- :param predictionAndObservations: an RDD of (prediction, observation) pairs.
- """
sc = predictionAndObservations.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
@@ -148,6 +147,8 @@ class MulticlassMetrics(JavaModelWrapper):
"""
Evaluator for multiclass classification.
+ :param predictionAndLabels: an RDD of (prediction, label) pairs.
+
>>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
>>> metrics = MulticlassMetrics(predictionAndLabels)
@@ -176,9 +177,6 @@ class MulticlassMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndLabels):
- """
- :param predictionAndLabels an RDD of (prediction, label) pairs.
- """
sc = predictionAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
@@ -277,6 +275,9 @@ class RankingMetrics(JavaModelWrapper):
"""
Evaluator for ranking algorithms.
+ :param predictionAndLabels: an RDD of (predicted ranking,
+ ground truth set) pairs.
+
>>> predictionAndLabels = sc.parallelize([
... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
@@ -298,9 +299,6 @@ class RankingMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndLabels):
- """
- :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
- """
sc = predictionAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndLabels,
@@ -347,6 +345,10 @@ class MultilabelMetrics(JavaModelWrapper):
"""
Evaluator for multilabel classification.
+ :param predictionAndLabels: an RDD of (predictions, labels) pairs,
+ both are non-null Arrays, each with
+ unique elements.
+
>>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
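Since the constructor parameters are now documented at class level, a short usage sketch of two of these evaluators may help; it assumes an active SparkContext `sc` and toy score/label pairs:

    from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics

    scoreAndLabels = sc.parallelize([(0.1, 0.0), (0.4, 0.0), (0.6, 1.0), (0.8, 1.0)], 2)
    bin_metrics = BinaryClassificationMetrics(scoreAndLabels)
    print(bin_metrics.areaUnderROC)   # threshold-free ranking quality
    print(bin_metrics.areaUnderPR)

    predictionAndObservations = sc.parallelize([(2.5, 3.0), (0.0, -0.5), (2.0, 2.0)])
    reg_metrics = RegressionMetrics(predictionAndObservations)
    print(reg_metrics.rootMeanSquaredError)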
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index aac305db6c19..da90554f4143 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -68,6 +68,8 @@ class Normalizer(VectorTransformer):
For `p` = float('inf'), max(abs(vector)) will be used as norm for
normalization.
+ :param p: Normalization in L^p^ space, p = 2 by default.
+
>>> v = Vectors.dense(range(3))
>>> nor = Normalizer(1)
>>> nor.transform(v)
@@ -82,9 +84,6 @@ class Normalizer(VectorTransformer):
DenseVector([0.0, 0.5, 1.0])
"""
def __init__(self, p=2.0):
- """
- :param p: Normalization in L^p^ space, p = 2 by default.
- """
assert p >= 1.0, "p should be greater than 1.0"
self.p = float(p)
@@ -94,7 +93,7 @@ def transform(self, vector):
:param vector: vector or RDD of vector to be normalized.
:return: normalized vector. If the norm of the input is zero, it
- will return the input vector.
+ will return the input vector.
"""
sc = SparkContext._active_spark_context
assert sc is not None, "SparkContext should be initialized first"
@@ -164,6 +163,13 @@ class StandardScaler(object):
variance using column summary statistics on the samples in the
training set.
+ :param withMean: False by default. Centers the data with mean
+ before scaling. It will build a dense output, so this
+ does not work on sparse input and will raise an
+ exception.
+ :param withStd: True by default. Scales the data to unit
+ standard deviation.
+
>>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])]
>>> dataset = sc.parallelize(vs)
>>> standardizer = StandardScaler(True, True)
@@ -174,14 +180,6 @@ class StandardScaler(object):
DenseVector([0.7071, -0.7071, 0.7071])
"""
def __init__(self, withMean=False, withStd=True):
- """
- :param withMean: False by default. Centers the data with mean
- before scaling. It will build a dense output, so this
- does not work on sparse input and will raise an
- exception.
- :param withStd: True by default. Scales the data to unit
- standard deviation.
- """
if not (withMean or withStd):
warnings.warn("Both withMean and withStd are false. The model does nothing.")
self.withMean = withMean
@@ -193,7 +191,7 @@ def fit(self, dataset):
for later scaling.
:param data: The data used to compute the mean and variance
- to build the transformation model.
+ to build the transformation model.
:return: a StandardScalarModel
"""
dataset = dataset.map(_convert_to_vector)
@@ -223,6 +221,8 @@ class ChiSqSelector(object):
Creates a ChiSquared feature selector.
+ :param numTopFeatures: number of features that selector will select.
+
>>> data = [
... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
@@ -236,9 +236,6 @@ class ChiSqSelector(object):
DenseVector([5.0])
"""
def __init__(self, numTopFeatures):
- """
- :param numTopFeatures: number of features that selector will select.
- """
self.numTopFeatures = int(numTopFeatures)
def fit(self, data):
@@ -246,9 +243,9 @@ def fit(self, data):
Returns a ChiSquared feature selector.
:param data: an `RDD[LabeledPoint]` containing the labeled dataset
- with categorical features. Real-valued features will be
- treated as categorical for each distinct value.
- Apply feature discretizer before using this function.
+ with categorical features. Real-valued features will be
+ treated as categorical for each distinct value.
+ Apply feature discretizer before using this function.
"""
jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
return ChiSqSelectorModel(jmodel)
@@ -263,15 +260,14 @@ class HashingTF(object):
Note: the terms must be hashable (can not be dict/set/list...).
+ :param numFeatures: number of features (default: 2^20)
+
>>> htf = HashingTF(100)
>>> doc = "a a b b c d".split(" ")
>>> htf.transform(doc)
SparseVector(100, {...})
"""
def __init__(self, numFeatures=1 << 20):
- """
- :param numFeatures: number of features (default: 2^20)
- """
self.numFeatures = numFeatures
def indexOf(self, term):
@@ -311,7 +307,7 @@ def transform(self, x):
Call transform directly on the RDD instead.
:param x: an RDD of term frequency vectors or a term frequency
- vector
+ vector
:return: an RDD of TF-IDF vectors or a TF-IDF vector
"""
if isinstance(x, RDD):
@@ -342,6 +338,9 @@ class IDF(object):
`minDocFreq`). For terms that are not in at least `minDocFreq`
documents, the IDF is found as 0, resulting in TF-IDFs of 0.
+ :param minDocFreq: minimum number of documents in which a term
+ should appear for filtering
+
>>> n = 4
>>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)),
... Vectors.dense([0.0, 1.0, 2.0, 3.0]),
@@ -362,10 +361,6 @@ class IDF(object):
SparseVector(4, {1: 0.0, 3: 0.5754})
"""
def __init__(self, minDocFreq=0):
- """
- :param minDocFreq: minimum of documents in which a term
- should appear for filtering
- """
self.minDocFreq = minDocFreq
def fit(self, dataset):
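The HashingTF and IDF parameters documented above compose into the usual TF-IDF flow. A hedged sketch, assuming an active SparkContext `sc`:

    from pyspark.mllib.feature import HashingTF, IDF

    docs = sc.parallelize([
        "spark makes big data simple".split(" "),
        "spark streaming makes streams simple".split(" "),
    ])
    htf = HashingTF(1 << 18)     # numFeatures sets the hash space (default 2^20)
    tf = htf.transform(docs)     # one sparse term-frequency vector per document
    tf.cache()                   # fit() and transform() both traverse it
    idf_model = IDF(minDocFreq=1).fit(tf)
    tfidf = idf_model.transform(tf)
    print(tfidf.first())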
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 23d1a79ffe51..e96c5ef87df8 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -36,7 +36,7 @@
import numpy as np
from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
- IntegerType, ByteType
+ IntegerType, ByteType, BooleanType
__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors',
@@ -163,6 +163,59 @@ def simpleString(self):
return "vector"
+class MatrixUDT(UserDefinedType):
+ """
+ SQL user-defined type (UDT) for Matrix.
+ """
+
+ @classmethod
+ def sqlType(cls):
+ return StructType([
+ StructField("type", ByteType(), False),
+ StructField("numRows", IntegerType(), False),
+ StructField("numCols", IntegerType(), False),
+ StructField("colPtrs", ArrayType(IntegerType(), False), True),
+ StructField("rowIndices", ArrayType(IntegerType(), False), True),
+ StructField("values", ArrayType(DoubleType(), False), True),
+ StructField("isTransposed", BooleanType(), False)])
+
+ @classmethod
+ def module(cls):
+ return "pyspark.mllib.linalg"
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.mllib.linalg.MatrixUDT"
+
+ def serialize(self, obj):
+ if isinstance(obj, SparseMatrix):
+ colPtrs = [int(i) for i in obj.colPtrs]
+ rowIndices = [int(i) for i in obj.rowIndices]
+ values = [float(v) for v in obj.values]
+ return (0, obj.numRows, obj.numCols, colPtrs,
+ rowIndices, values, bool(obj.isTransposed))
+ elif isinstance(obj, DenseMatrix):
+ values = [float(v) for v in obj.values]
+ return (1, obj.numRows, obj.numCols, None, None, values,
+ bool(obj.isTransposed))
+ else:
+ raise TypeError("cannot serialize type %r" % (type(obj)))
+
+ def deserialize(self, datum):
+ assert len(datum) == 7, \
+ "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
+ tpe = datum[0]
+ if tpe == 0:
+ return SparseMatrix(*datum[1:])
+ elif tpe == 1:
+ return DenseMatrix(datum[1], datum[2], datum[5], datum[6])
+ else:
+ raise ValueError("do not recognize type %r" % tpe)
+
+ def simpleString(self):
+ return "matrix"
+
+
class Vector(object):
__UDT__ = VectorUDT()
@@ -781,10 +834,12 @@ def zeros(size):
class Matrix(object):
+
+ __UDT__ = MatrixUDT()
+
"""
Represents a local matrix.
"""
-
def __init__(self, numRows, numCols, isTransposed=False):
self.numRows = numRows
self.numCols = numCols
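A round-trip sketch for the new MatrixUDT: serialize() emits a tuple matching sqlType()'s fields (type tag, shape, CSC arrays, values, isTransposed) and deserialize() rebuilds the matrix. The matrices below are illustrative:

    import numpy as np
    from pyspark.mllib.linalg import DenseMatrix, SparseMatrix, MatrixUDT

    udt = MatrixUDT()
    dense = DenseMatrix(2, 2, [1.0, 3.0, 2.0, 4.0])              # column-major values
    sparse = SparseMatrix(2, 2, [0, 1, 2], [0, 1], [9.0, 8.0])   # CSC: colPtrs, rowIndices, values

    for m in (dense, sparse):
        row = udt.serialize(m)          # row[0] is the type tag: 0 sparse, 1 dense
        restored = udt.deserialize(row)
        assert np.array_equal(restored.toArray(), m.toArray())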
diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/random.py
similarity index 100%
rename from python/pyspark/mllib/rand.py
rename to python/pyspark/mllib/random.py
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 41bde2ce3e60..0c4d7d3bbee0 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -33,12 +33,12 @@
class LabeledPoint(object):
"""
- The features and labels of a data point.
+ Class that represents the features and labels of a data point.
:param label: Label for this data point.
:param features: Vector of features for this point (NumPy array,
- list, pyspark.mllib.linalg.SparseVector, or scipy.sparse
- column matrix)
+ list, pyspark.mllib.linalg.SparseVector, or scipy.sparse
+ column matrix)
Note: 'label' and 'features' are accessible as class attributes.
"""
@@ -59,7 +59,12 @@ def __repr__(self):
class LinearModel(object):
- """A linear model that has a vector of coefficients and an intercept."""
+ """
+ A linear model that has a vector of coefficients and an intercept.
+
+ :param weights: Weights computed for every feature.
+ :param intercept: Intercept computed for this model.
+ """
def __init__(self, weights, intercept):
self._coeff = _convert_to_vector(weights)
@@ -193,18 +198,28 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
initialWeights=None, regParam=0.0, regType=None, intercept=False,
validateData=True):
"""
- Train a linear regression model on the given data.
-
- :param data: The training data.
- :param iterations: The number of iterations (default: 100).
+ Train a linear regression model using Stochastic Gradient
+ Descent (SGD).
+ This solves the least squares regression formulation
+ f(weights) = 1/n ||A weights-y||^2^
+ (which is the mean squared error).
+ Here the data matrix has n rows, and the input RDD holds the
+ set of rows of A, each with its corresponding right hand side
+ label y. See also the documentation for the precise formulation.
+
+ :param data: The training data, an RDD of
+ LabeledPoint.
+ :param iterations: The number of iterations
+ (default: 100).
:param step: The step parameter used in SGD
(default: 1.0).
- :param miniBatchFraction: Fraction of data to be used for each SGD
- iteration.
+ :param miniBatchFraction: Fraction of data to be used for each
+ SGD iteration (default: 1.0).
:param initialWeights: The initial weights (default: None).
- :param regParam: The regularizer parameter (default: 0.0).
- :param regType: The type of regularizer used for training
- our model.
+ :param regParam: The regularizer parameter
+ (default: 0.0).
+ :param regType: The type of regularizer used for
+ training our model.
:Allowed values:
- "l1" for using L1 regularization (lasso),
@@ -213,13 +228,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
(default: None)
- :param intercept: Boolean parameter which indicates the use
- or not of the augmented representation for
- training data (i.e. whether bias features
- are activated or not). (default: False)
- :param validateData: Boolean parameter which indicates if the
- algorithm should validate data before training.
- (default: True)
+ :param intercept: Boolean parameter which indicates the
+ use or not of the augmented representation
+ for training data (i.e. whether bias
+ features are activated or not,
+ default: False).
+ :param validateData: Boolean parameter which indicates if
+ the algorithm should validate data
+ before training. (default: True)
"""
def train(rdd, i):
return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations),
@@ -232,8 +248,8 @@ def train(rdd, i):
@inherit_doc
class LassoModel(LinearRegressionModelBase):
- """A linear regression model derived from a least-squares fit with an
- l_1 penalty term.
+ """A linear regression model derived from a least-squares fit with
+ an l_1 penalty term.
>>> from pyspark.mllib.regression import LabeledPoint
>>> data = [
@@ -304,7 +320,36 @@ class LassoWithSGD(object):
def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None, intercept=False,
validateData=True):
- """Train a Lasso regression model on the given data."""
+ """
+ Train a regression model with L1-regularization using
+ Stochastic Gradient Descent.
+ This solves the l1-regularized least squares regression
+ formulation
+ f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1
+ Here the data matrix has n rows, and the input RDD holds the
+ set of rows of A, each with its corresponding right hand side
+ label y. See also the documentation for the precise formulation.
+
+ :param data: The training data, an RDD of
+ LabeledPoint.
+ :param iterations: The number of iterations
+ (default: 100).
+ :param step: The step parameter used in SGD
+ (default: 1.0).
+ :param regParam: The regularizer parameter
+ (default: 0.01).
+ :param miniBatchFraction: Fraction of data to be used for each
+ SGD iteration (default: 1.0).
+ :param initialWeights: The initial weights (default: None).
+ :param intercept: Boolean parameter which indicates the
+ use or not of the augmented representation
+ for training data (i.e. whether bias
+ features are activated or not,
+ default: False).
+ :param validateData: Boolean parameter which indicates if
+ the algorithm should validate data
+ before training. (default: True)
+ """
def train(rdd, i):
return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step),
float(regParam), float(miniBatchFraction), i, bool(intercept),
@@ -316,8 +361,8 @@ def train(rdd, i):
@inherit_doc
class RidgeRegressionModel(LinearRegressionModelBase):
- """A linear regression model derived from a least-squares fit with an
- l_2 penalty term.
+ """A linear regression model derived from a least-squares fit with
+ an l_2 penalty term.
>>> from pyspark.mllib.regression import LabeledPoint
>>> data = [
@@ -389,7 +434,36 @@ class RidgeRegressionWithSGD(object):
def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None, intercept=False,
validateData=True):
- """Train a ridge regression model on the given data."""
+ """
+ Train a regression model with L2-regularization using
+ Stochastic Gradient Descent.
+ This solves the l2-regularized least squares regression
+ formulation
+ f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^
+ Here the data matrix has n rows, and the input RDD holds the
+ set of rows of A, each with its corresponding right hand side
+ label y. See also the documentation for the precise formulation.
+
+ :param data: The training data, an RDD of
+ LabeledPoint.
+ :param iterations: The number of iterations
+ (default: 100).
+ :param step: The step parameter used in SGD
+ (default: 1.0).
+ :param regParam: The regularizer parameter
+ (default: 0.01).
+ :param miniBatchFraction: Fraction of data to be used for each
+ SGD iteration (default: 1.0).
+ :param initialWeights: The initial weights (default: None).
+ :param intercept: Boolean parameter which indicates the
+ use or not of the augmented representation
+ for training data (i.e. whether bias
+ features are activated or not,
+ default: False).
+ :param validateData: Boolean parameter which indicates if
+ the algorithm should validate data
+ before training. (default: True)
+ """
def train(rdd, i):
return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step),
float(regParam), float(miniBatchFraction), i, bool(intercept),
@@ -400,7 +474,15 @@ def train(rdd, i):
class IsotonicRegressionModel(Saveable, Loader):
- """Regression model for isotonic regression.
+ """
+ Regression model for isotonic regression.
+
+ :param boundaries: Array of boundaries for which predictions are
+ known. Boundaries must be sorted in increasing order.
+ :param predictions: Array of predictions associated to the
+ boundaries at the same index. Results of isotonic
+ regression and therefore monotone.
+ :param isotonic: indicates whether this is isotonic or antitonic.
>>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)]
>>> irm = IsotonicRegression.train(sc.parallelize(data))
@@ -430,6 +512,25 @@ def __init__(self, boundaries, predictions, isotonic):
self.isotonic = isotonic
def predict(self, x):
+ """
+ Predict labels for provided features using a piecewise
+ linear function.
+ 1) If x exactly matches a boundary then associated prediction
+ is returned. In case there are multiple predictions with the
+ same boundary then one of them is returned. Which one is
+ undefined (same as java.util.Arrays.binarySearch).
+ 2) If x is lower or higher than all boundaries then first or
+ last prediction is returned respectively. In case there are
+ multiple predictions with the same boundary then the lowest
+ or highest is returned respectively.
+ 3) If x falls between two values in boundary array then
+ prediction is treated as piecewise linear function and
+ interpolated value is returned. In case there are multiple
+ values with the same boundary then the same rules as in 2)
+ are used.
+
+ :param x: Feature or RDD of Features to be labeled.
+ """
if isinstance(x, RDD):
return x.map(lambda v: self.predict(v))
return np.interp(x, self.boundaries, self.predictions)
@@ -451,15 +552,15 @@ def load(cls, sc, path):
class IsotonicRegression(object):
- """
- Run IsotonicRegression algorithm to obtain isotonic regression model.
- :param data: RDD of (label, feature, weight) tuples.
- :param isotonic: Whether this is isotonic or antitonic.
- """
@classmethod
def train(cls, data, isotonic=True):
- """Train a isotonic regression model on the given data."""
+ """
+ Train an isotonic regression model on the given data.
+
+ :param data: RDD of (label, feature, weight) tuples.
+ :param isotonic: Whether this is isotonic or antitonic.
+ """
boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel",
data.map(_convert_to_vector), bool(isotonic))
return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic)
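The three prediction rules spelled out in IsotonicRegressionModel.predict reduce to linear interpolation over (boundaries, predictions). A hedged sketch using the doctest's toy data, assuming an active SparkContext `sc`:

    from pyspark.mllib.regression import IsotonicRegression

    data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)]
    irm = IsotonicRegression.train(sc.parallelize(data))   # (label, feature, weight) tuples

    print(irm.predict(-1.0))    # below all boundaries -> first prediction (rule 2)
    print(irm.predict(100.0))   # above all boundaries -> last prediction (rule 2)
    print(irm.predict(2.5))     # between boundaries -> interpolated value (rule 3)
    print(irm.predict(sc.parallelize([0.0, 2.5, 7.0])).collect())   # RDD input also works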
diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py
new file mode 100644
index 000000000000..7da921976d4d
--- /dev/null
+++ b/python/pyspark/mllib/stat/KernelDensity.py
@@ -0,0 +1,61 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+
+if sys.version > '3':
+ xrange = range
+
+import numpy as np
+
+from pyspark.mllib.common import callMLlibFunc
+from pyspark.rdd import RDD
+
+
+class KernelDensity(object):
+ """
+ .. note:: Experimental
+
+ Estimate probability density at required points given an RDD of samples
+ from the population.
+
+ >>> kd = KernelDensity()
+ >>> sample = sc.parallelize([0.0, 1.0])
+ >>> kd.setSample(sample)
+ >>> kd.estimate([0.0, 1.0])
+ array([ 0.12938758, 0.12938758])
+ """
+ def __init__(self):
+ self._bandwidth = 1.0
+ self._sample = None
+
+ def setBandwidth(self, bandwidth):
+ """Set bandwidth of each sample. Defaults to 1.0"""
+ self._bandwidth = bandwidth
+
+ def setSample(self, sample):
+ """Set sample points from the population. Should be a RDD"""
+ if not isinstance(sample, RDD):
+ raise TypeError("samples should be a RDD, received %s" % type(sample))
+ self._sample = sample
+
+ def estimate(self, points):
+ """Estimate the probability density at points"""
+ points = list(points)
+ densities = callMLlibFunc(
+ "estimateKernelDensity", self._sample, self._bandwidth, points)
+ return np.asarray(densities)
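A usage sketch for the new estimator with a non-default bandwidth; it assumes an active SparkContext `sc`, and the sample values are made up:

    from pyspark.mllib.stat import KernelDensity   # re-exported by the __init__ change below

    kd = KernelDensity()
    kd.setSample(sc.parallelize([0.0, 1.0, 2.0, 5.0]))
    kd.setBandwidth(0.5)                        # default is 1.0
    densities = kd.estimate([0.0, 2.5, 10.0])   # numpy array, one density per query point
    print(densities)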
diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py
index e3e128513e0d..c8a721d3fe41 100644
--- a/python/pyspark/mllib/stat/__init__.py
+++ b/python/pyspark/mllib/stat/__init__.py
@@ -22,6 +22,7 @@
from pyspark.mllib.stat._statistics import *
from pyspark.mllib.stat.distribution import MultivariateGaussian
from pyspark.mllib.stat.test import ChiSqTestResult
+from pyspark.mllib.stat.KernelDensity import KernelDensity
__all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult",
- "MultivariateGaussian"]
+ "MultivariateGaussian", "KernelDensity"]
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 36a4c7a5408c..f4c997261ef4 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -39,7 +39,7 @@
from pyspark import SparkContext
from pyspark.mllib.common import _to_java_object_rdd
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
- DenseMatrix, SparseMatrix, Vectors, Matrices
+ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
@@ -507,6 +507,38 @@ def test_infer_schema(self):
raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
+class MatrixUDTTests(MLlibTestCase):
+
+ dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
+ dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
+ sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
+ sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
+ udt = MatrixUDT()
+
+ def test_json_schema(self):
+ self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+ def test_serialization(self):
+ for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
+ self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
+
+ def test_infer_schema(self):
+ sqlCtx = SQLContext(self.sc)
+ rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
+ df = rdd.toDF()
+ schema = df.schema
+ self.assertTrue(schema.fields[1].dataType, self.udt)
+ matrices = df.map(lambda x: x._2).collect()
+ self.assertEqual(len(matrices), 2)
+ for m in matrices:
+ if isinstance(m, DenseMatrix):
+ self.assertTrue(m, self.dm1)
+ elif isinstance(m, SparseMatrix):
+ self.assertTrue(m, self.sm1)
+ else:
+ raise ValueError("Expected a matrix but got type %r" % type(m))
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(MLlibTestCase):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 98a8ff860636..20c0bc93f413 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -960,7 +960,7 @@ def sum(self):
>>> sc.parallelize([1.0, 2.0, 3.0]).sum()
6.0
"""
- return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
+ return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
def count(self):
"""
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 8fee92ae3aed..ad9c891ba1c0 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -45,22 +45,19 @@
def since(version):
+ """
+ A decorator that annotates a function to append the version of Spark in which the function was added.
+ """
+ import re
+ indent_p = re.compile(r'\n( +)')
+
def deco(f):
- f.__doc__ = f.__doc__.rstrip() + "\n\n.. versionadded:: %s" % version
+ indents = indent_p.findall(f.__doc__)
+ indent = ' ' * (min(len(m) for m in indents) if indents else 0)
+ f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
return f
return deco
-# fix the module name conflict for Python 3+
-import sys
-from . import _types as types
-modname = __name__ + '.types'
-types.__name__ = modname
-# update the __module__ for all objects, make them picklable
-for v in types.__dict__.values():
- if hasattr(v, "__module__") and v.__module__.endswith('._types'):
- v.__module__ = modname
-sys.modules[modname] = types
-del modname, sys
from pyspark.sql.types import Row
from pyspark.sql.context import SQLContext, HiveContext
@@ -70,7 +67,9 @@ def deco(f):
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
from pyspark.sql.window import Window, WindowSpec
+
__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
+ 'DataFrameReader', 'DataFrameWriter'
]
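The reworked decorator measures the docstring's indentation so the appended Sphinx directive lines up with the body. A standalone illustration (the `demo` function is made up, not part of the patch):

    import re

    def since(version):
        indent_p = re.compile(r'\n( +)')

        def deco(f):
            indents = indent_p.findall(f.__doc__)
            indent = ' ' * (min(len(m) for m in indents) if indents else 0)
            f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
            return f
        return deco

    @since("1.4")
    def demo():
        """Docstring body.

        Indented continuation line.
        """

    print(demo.__doc__)   # the versionadded directive now matches the body's indentation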
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 8dc5039f587f..1ecec5b12650 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -315,6 +315,14 @@ def between(self, lowerBound, upperBound):
"""
A boolean expression that is evaluated to true if the value of this
expression is between the given columns.
+
+ >>> df.select(df.name, df.age.between(2, 4)).show()
+ +-----+--------------------------+
+ | name|((age >= 2) && (age <= 4))|
+ +-----+--------------------------+
+ |Alice| true|
+ | Bob| false|
+ +-----+--------------------------+
"""
return (self >= lowerBound) & (self <= upperBound)
@@ -328,12 +336,20 @@ def when(self, condition, value):
:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.
+
+ >>> from pyspark.sql import functions as F
+ >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
+ +-----+--------------------------------------------------------+
+ | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
+ +-----+--------------------------------------------------------+
+ |Alice| -1|
+ | Bob| 1|
+ +-----+--------------------------------------------------------+
"""
- sc = SparkContext._active_spark_context
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")
v = value._jc if isinstance(value, Column) else value
- jc = sc._jvm.functions.when(condition._jc, v)
+ jc = self._jc.when(condition._jc, v)
return Column(jc)
@since(1.4)
@@ -345,9 +361,18 @@ def otherwise(self, value):
See :func:`pyspark.sql.functions.when` for example usage.
:param value: a literal value, or a :class:`Column` expression.
+
+ >>> from pyspark.sql import functions as F
+ >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
+ +-----+---------------------------------+
+ | name|CASE WHEN (age > 3) THEN 1 ELSE 0|
+ +-----+---------------------------------+
+ |Alice| 0|
+ | Bob| 1|
+ +-----+---------------------------------+
"""
v = value._jc if isinstance(value, Column) else value
- jc = self._jc.otherwise(value)
+ jc = self._jc.otherwise(v)
return Column(jc)
@since(1.4)
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 22f6257dfe02..599c9ac5794a 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -124,11 +124,14 @@ def getConf(self, key, defaultValue):
@property
@since("1.3.1")
def udf(self):
- """Returns a :class:`UDFRegistration` for UDF registration."""
+ """Returns a :class:`UDFRegistration` for UDF registration.
+
+ :return: :class:`UDFRegistration`
+ """
return UDFRegistration(self)
@since(1.4)
- def range(self, start, end, step=1, numPartitions=None):
+ def range(self, start, end=None, step=1, numPartitions=None):
"""
Create a :class:`DataFrame` with single LongType column named `id`,
containing elements in a range from `start` to `end` (exclusive) with
@@ -138,14 +141,24 @@ def range(self, start, end, step=1, numPartitions=None):
:param end: the end value (exclusive)
:param step: the incremental step (default: 1)
:param numPartitions: the number of partitions of the DataFrame
- :return: A new DataFrame
+ :return: :class:`DataFrame`
>>> sqlContext.range(1, 7, 2).collect()
[Row(id=1), Row(id=3), Row(id=5)]
+
+ If only one argument is specified, it will be used as the end value.
+
+ >>> sqlContext.range(3).collect()
+ [Row(id=0), Row(id=1), Row(id=2)]
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
- jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+
+ if end is None:
+ jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions))
+ else:
+ jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+
return DataFrame(jdf, self)
@ignore_unicode_prefix
@@ -195,8 +208,8 @@ def _inferSchema(self, rdd, samplingRatio=None):
raise ValueError("The first row in RDD is empty, "
"can not infer schema")
if type(first) is dict:
- warnings.warn("Using RDD of dict to inferSchema is deprecated,"
- "please use pyspark.sql.Row instead")
+ warnings.warn("Using RDD of dict to inferSchema is deprecated. "
+ "Use pyspark.sql.Row instead")
if samplingRatio is None:
schema = _infer_schema(first)
@@ -219,7 +232,7 @@ def inferSchema(self, rdd, samplingRatio=None):
"""
.. note:: Deprecated in 1.3, use :func:`createDataFrame` instead.
"""
- warnings.warn("inferSchema is deprecated, please use createDataFrame instead")
+ warnings.warn("inferSchema is deprecated, please use createDataFrame instead.")
if isinstance(rdd, DataFrame):
raise TypeError("Cannot apply schema to DataFrame")
@@ -262,6 +275,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
:class:`list`, or :class:`pandas.DataFrame`.
:param schema: a :class:`StructType` or list of column names. default None.
:param samplingRatio: the sample ratio of rows used for inferring
+ :return: :class:`DataFrame`
>>> l = [('Alice', 1)]
>>> sqlContext.createDataFrame(l).collect()
@@ -359,18 +373,15 @@ def registerDataFrameAsTable(self, df, tableName):
else:
raise ValueError("Can only register DataFrame as table")
- @since(1.0)
def parquetFile(self, *paths):
"""Loads a Parquet file, returning the result as a :class:`DataFrame`.
- >>> import tempfile, shutil
- >>> parquetFile = tempfile.mkdtemp()
- >>> shutil.rmtree(parquetFile)
- >>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlContext.parquetFile(parquetFile)
- >>> sorted(df.collect()) == sorted(df2.collect())
- True
+ .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead.
+
+ >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes
+ [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
+ warnings.warn("parquetFile is deprecated. Use read.parquet() instead.")
gateway = self._sc._gateway
jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths))
for i in range(0, len(paths)):
@@ -378,39 +389,15 @@ def parquetFile(self, *paths):
jdf = self._ssql_ctx.parquetFile(jpaths)
return DataFrame(jdf, self)
- @since(1.0)
def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""Loads a text file storing one JSON object per line as a :class:`DataFrame`.
- If the schema is provided, applies the given schema to this JSON dataset.
- Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema.
+ .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead.
- >>> import tempfile, shutil
- >>> jsonFile = tempfile.mkdtemp()
- >>> shutil.rmtree(jsonFile)
- >>> with open(jsonFile, 'w') as f:
- ... f.writelines(jsonStrings)
- >>> df1 = sqlContext.jsonFile(jsonFile)
- >>> df1.printSchema()
- root
- |-- field1: long (nullable = true)
- |-- field2: string (nullable = true)
- |-- field3: struct (nullable = true)
- | |-- field4: long (nullable = true)
-
- >>> from pyspark.sql.types import *
- >>> schema = StructType([
- ... StructField("field2", StringType()),
- ... StructField("field3",
- ... StructType([StructField("field5", ArrayType(IntegerType()))]))])
- >>> df2 = sqlContext.jsonFile(jsonFile, schema)
- >>> df2.printSchema()
- root
- |-- field2: string (nullable = true)
- |-- field3: struct (nullable = true)
- | |-- field5: array (nullable = true)
- | | |-- element: integer (containsNull = true)
+ >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes
+ [('age', 'bigint'), ('name', 'string')]
"""
+ warnings.warn("jsonFile is deprecated. Use read.json() instead.")
if schema is None:
df = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
@@ -462,21 +449,16 @@ def func(iterator):
df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return DataFrame(df, self)
- @since(1.3)
def load(self, path=None, source=None, schema=None, **options):
"""Returns the dataset in a data source as a :class:`DataFrame`.
- The data source is specified by the ``source`` and a set of ``options``.
- If ``source`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- Optionally, a schema can be provided as the schema of the returned DataFrame.
+ .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead.
"""
+ warnings.warn("load is deprecated. Use read.load() instead.")
return self.read.load(path, source, schema, **options)
@since(1.3)
- def createExternalTable(self, tableName, path=None, source=None,
- schema=None, **options):
+ def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
"""Creates an external table based on the dataset in a data source.
It returns the DataFrame associated with the external table.
@@ -487,6 +469,8 @@ def createExternalTable(self, tableName, path=None, source=None,
Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
created external table.
+
+ :return: :class:`DataFrame`
"""
if path is not None:
options["path"] = path
@@ -508,6 +492,8 @@ def createExternalTable(self, tableName, path=None, source=None,
def sql(self, sqlQuery):
"""Returns a :class:`DataFrame` representing the result of the given query.
+ :return: :class:`DataFrame`
+
>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
@@ -519,6 +505,8 @@ def sql(self, sqlQuery):
def table(self, tableName):
"""Returns the specified table as a :class:`DataFrame`.
+ :return: :class:`DataFrame`
+
>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
@@ -536,6 +524,9 @@ def tables(self, dbName=None):
The returned DataFrame has two columns: ``tableName`` and ``isTemporary``
(a column with :class:`BooleanType` indicating if a table is a temporary one or not).
+ :param dbName: string, name of the database to use.
+ :return: :class:`DataFrame`
+
>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.tables()
>>> df2.filter("tableName = 'table1'").first()
@@ -550,7 +541,8 @@ def tables(self, dbName=None):
def tableNames(self, dbName=None):
"""Returns a list of names of tables in the database ``dbName``.
- If ``dbName`` is not specified, the current database will be used.
+ :param dbName: string, name of the database to use. Defaults to the current database.
+ :return: list of table names, as strings
>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> "table1" in sqlContext.tableNames()
@@ -585,8 +577,7 @@ def read(self):
Returns a :class:`DataFrameReader` that can be used to read data
in as a :class:`DataFrame`.
- >>> sqlContext.read
-
+ :return: :class:`DataFrameReader`
"""
return DataFrameReader(self)
@@ -644,10 +635,14 @@ def register(self, name, f, returnType=StringType()):
def _test():
+ import os
import doctest
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
import pyspark.sql.context
+
+ os.chdir(os.environ["SPARK_HOME"])
+
globs = pyspark.sql.context.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
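The deprecation notes above all point at the DataFrameReader API. A hedged sketch of the replacements, assuming an active SQLContext `sqlContext` and the test_support paths used by the doctests:

    df_parquet = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
    df_json = sqlContext.read.json('python/test_support/sql/people.json')
    print(df_parquet.dtypes)
    print(df_json.dtypes)

    # range() now also accepts a single argument, interpreted as the end value.
    print(sqlContext.range(3).collect())        # [Row(id=0), Row(id=1), Row(id=2)]
    print(sqlContext.range(1, 7, 2).collect())  # [Row(id=1), Row(id=3), Row(id=5)]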
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 936487519a64..152b87351db3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -22,6 +22,7 @@
if sys.version >= '3':
basestring = unicode = str
long = int
+ from functools import reduce
else:
from itertools import imap as map
@@ -44,7 +45,7 @@ class DataFrame(object):
A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
and can be created using various functions in :class:`SQLContext`::
- people = sqlContext.parquetFile("...")
+ people = sqlContext.read.parquet("...")
Once created, it can be manipulated using the various domain-specific-language
(DSL) functions defined in: :class:`DataFrame`, :class:`Column`.
@@ -56,8 +57,8 @@ class DataFrame(object):
A more concrete example::
# To create DataFrame using SQLContext
- people = sqlContext.parquetFile("...")
- department = sqlContext.parquetFile("...")
+ people = sqlContext.read.parquet("...")
+ department = sqlContext.read.parquet("...")
people.filter(people.age > 30).join(department, people.deptId == department.id)) \
.groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
@@ -120,21 +121,12 @@ def toJSON(self, use_unicode=True):
rdd = self._jdf.toJSON()
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
- @since(1.3)
def saveAsParquetFile(self, path):
"""Saves the contents as a Parquet file, preserving the schema.
- Files that are written out using this method can be read back in as
- a :class:`DataFrame` using :func:`SQLContext.parquetFile`.
-
- >>> import tempfile, shutil
- >>> parquetFile = tempfile.mkdtemp()
- >>> shutil.rmtree(parquetFile)
- >>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlContext.parquetFile(parquetFile)
- >>> sorted(df2.collect()) == sorted(df.collect())
- True
+ .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead.
"""
+ warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.")
self._jdf.saveAsParquetFile(path)
@since(1.3)
@@ -151,69 +143,45 @@ def registerTempTable(self, name):
"""
self._jdf.registerTempTable(name)
- @since(1.3)
def registerAsTable(self, name):
- """DEPRECATED: use :func:`registerTempTable` instead"""
- warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
+ """
+ .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead.
+ """
+ warnings.warn("Use registerTempTable instead of registerAsTable.")
self.registerTempTable(name)
- @since(1.3)
def insertInto(self, tableName, overwrite=False):
"""Inserts the contents of this :class:`DataFrame` into the specified table.
- Optionally overwriting any existing data.
+ .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead.
"""
+ warnings.warn("insertInto is deprecated. Use write.insertInto() instead.")
self.write.insertInto(tableName, overwrite)
- @since(1.3)
def saveAsTable(self, tableName, source=None, mode="error", **options):
"""Saves the contents of this :class:`DataFrame` to a data source as a table.
- The data source is specified by the ``source`` and a set of ``options``.
- If ``source`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- Additionally, mode is used to specify the behavior of the saveAsTable operation when
- table already exists in the data source. There are four modes:
-
- * `append`: Append contents of this :class:`DataFrame` to existing data.
- * `overwrite`: Overwrite existing data.
- * `error`: Throw an exception if data already exists.
- * `ignore`: Silently ignore this operation if data already exists.
+ .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead.
"""
+ warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.")
self.write.saveAsTable(tableName, source, mode, **options)
@since(1.3)
def save(self, path=None, source=None, mode="error", **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
- The data source is specified by the ``source`` and a set of ``options``.
- If ``source`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- Additionally, mode is used to specify the behavior of the save operation when
- data already exists in the data source. There are four modes:
-
- * `append`: Append contents of this :class:`DataFrame` to existing data.
- * `overwrite`: Overwrite existing data.
- * `error`: Throw an exception if data already exists.
- * `ignore`: Silently ignore this operation if data already exists.
+ .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead.
"""
+ warnings.warn("insertInto is deprecated. Use write.save() instead.")
return self.write.save(path, source, mode, **options)
@property
@since(1.4)
def write(self):
"""
- Interface for saving the content of the :class:`DataFrame` out
- into external storage.
-
- :return :class:`DataFrameWriter`
+ Interface for saving the content of the :class:`DataFrame` out into external storage.
- .. note:: Experimental
-
- >>> df.write
-
+ :return: :class:`DataFrameWriter`
"""
return DataFrameWriter(self)
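A minimal sketch (not part of the patch) of how the deprecated one-shot save methods above map onto the `df.write` builder they now delegate to; the SparkContext setup and output paths are illustrative:

    # Illustrative only: deprecated calls vs. their DataFrameWriter equivalents.
    import os
    import tempfile
    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local[2]", "writer-demo")   # hypothetical app name
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)])

    df.saveAsParquetFile(os.path.join(tempfile.mkdtemp(), 'data'))            # warns, still works
    df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))                # preferred

    df.save(os.path.join(tempfile.mkdtemp(), 'data'), source='json')          # warns
    df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data'))    # preferred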
@@ -226,7 +194,11 @@ def schema(self):
StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
"""
if self._schema is None:
- self._schema = _parse_datatype_json_string(self._jdf.schema().json())
+ try:
+ self._schema = _parse_datatype_json_string(self._jdf.schema().json())
+ except AttributeError as e:
+ raise Exception(
+ "Unable to parse datatype from schema. %s" % e)
return self._schema
@since(1.3)
@@ -536,36 +508,52 @@ def alias(self, alias):
@ignore_unicode_prefix
@since(1.3)
- def join(self, other, joinExprs=None, joinType=None):
+ def join(self, other, on=None, how=None):
"""Joins with another :class:`DataFrame`, using the given join expression.
The following performs a full outer join between ``df1`` and ``df2``.
:param other: Right side of the join
- :param joinExprs: a string for join column name, or a join expression (Column).
- If joinExprs is a string indicating the name of the join column,
- the column must exist on both sides, and this performs an inner equi-join.
- :param joinType: str, default 'inner'.
+ :param on: a string for the join column name, a list of column names,
+ a join expression (Column), or a list of Columns.
+ If `on` is a string or a list of strings indicating the name of the join column(s),
+ the column(s) must exist on both sides, and this performs an inner equi-join.
+ :param how: str, default 'inner'.
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
[Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
+ >>> cond = [df.name == df3.name, df.age == df3.age]
+ >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
+ [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)]
+
>>> df.join(df2, 'name').select(df.name, df2.height).collect()
[Row(name=u'Bob', height=85)]
+
+ >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
+ [Row(name=u'Bob', age=5)]
"""
- if joinExprs is None:
+ if on is not None and not isinstance(on, list):
+ on = [on]
+
+ if on is None or len(on) == 0:
jdf = self._jdf.join(other._jdf)
- elif isinstance(joinExprs, basestring):
- jdf = self._jdf.join(other._jdf, joinExprs)
+
+ elif isinstance(on[0], basestring):
+ jdf = self._jdf.join(other._jdf, self._jseq(on))
else:
- assert isinstance(joinExprs, Column), "joinExprs should be Column"
- if joinType is None:
- jdf = self._jdf.join(other._jdf, joinExprs._jc)
+ assert isinstance(on[0], Column), "on should be Column or list of Column"
+ if len(on) > 1:
+ on = reduce(lambda x, y: x.__and__(y), on)
+ else:
+ on = on[0]
+ if how is None:
+ jdf = self._jdf.join(other._jdf, on._jc, "inner")
else:
- assert isinstance(joinType, basestring), "joinType should be basestring"
- jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
+ assert isinstance(how, basestring), "how should be basestring"
+ jdf = self._jdf.join(other._jdf, on._jc, how)
return DataFrame(jdf, self.sql_ctx)
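A minimal sketch (not part of the patch) of the widened `join` signature: `on` accepts a column name, a list of names, a Column expression, or a list of Columns that the `reduce` step above ANDs together. The setup data is illustrative:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local[2]", "join-demo")     # hypothetical app name
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)])
    df3 = sqlContext.createDataFrame([Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)])

    df.join(df3, 'name').show()            # inner equi-join on a single column name
    df.join(df3, ['name', 'age']).show()   # equi-join on several column names

    # A list of Column expressions becomes one AND-ed join condition.
    cond = [df.name == df3.name, df.age == df3.age]
    df.join(df3, cond, 'outer').show()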
@ignore_unicode_prefix
@@ -636,6 +624,9 @@ def describe(self, *cols):
This includes count, mean, stddev, min, and max. If no columns are
given, this function computes statistics for all numerical columns.
+ .. note:: This function is meant for exploratory data analysis, as we make no \
+ guarantee about the backward compatibility of the schema of the resulting DataFrame.
+
>>> df.describe().show()
+-------+---+
|summary|age|
@@ -646,16 +637,30 @@ def describe(self, *cols):
| min| 2|
| max| 5|
+-------+---+
+ >>> df.describe(['age', 'name']).show()
+ +-------+---+-----+
+ |summary|age| name|
+ +-------+---+-----+
+ | count| 2| 2|
+ | mean|3.5| null|
+ | stddev|1.5| null|
+ | min| 2|Alice|
+ | max| 5| Bob|
+ +-------+---+-----+
"""
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0]
jdf = self._jdf.describe(self._jseq(cols))
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
@since(1.3)
def head(self, n=None):
- """
- Returns the first ``n`` rows as a list of :class:`Row`,
- or the first :class:`Row` if ``n`` is ``None.``
+ """Returns the first ``n`` rows.
+
+ :param n: int, default 1. Number of rows to return.
+ :return: If n is greater than 1, return a list of :class:`Row`.
+ If n is 1, return a single Row.
>>> df.head()
Row(age=2, name=u'Alice')
@@ -745,7 +750,7 @@ def selectExpr(self, *expr):
This is a variant of :func:`select` that accepts SQL expressions.
>>> df.selectExpr("age * 2", "abs(age)").collect()
- [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
+ [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)]
"""
if len(expr) == 1 and isinstance(expr[0], list):
expr = expr[0]
@@ -925,8 +930,7 @@ def dropDuplicates(self, subset=None):
@since("1.3.1")
def dropna(self, how='any', thresh=None, subset=None):
"""Returns a new :class:`DataFrame` omitting rows with null values.
-
- This is an alias for ``na.drop()``.
+ :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other.
:param how: 'any' or 'all'.
If 'any', drop a row if it contains any nulls.
@@ -936,13 +940,6 @@ def dropna(self, how='any', thresh=None, subset=None):
This overwrites the `how` parameter.
:param subset: optional list of column names to consider.
- >>> df4.dropna().show()
- +---+------+-----+
- |age|height| name|
- +---+------+-----+
- | 10| 80|Alice|
- +---+------+-----+
-
>>> df4.na.drop().show()
+---+------+-----+
|age|height| name|
@@ -968,6 +965,7 @@ def dropna(self, how='any', thresh=None, subset=None):
@since("1.3.1")
def fillna(self, value, subset=None):
"""Replace null values, alias for ``na.fill()``.
+ :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other.
:param value: int, long, float, string, or dict.
Value to replace null values with.
@@ -979,7 +977,7 @@ def fillna(self, value, subset=None):
For example, if `value` is a string, and subset contains a non-string column,
then the non-string column is simply ignored.
- >>> df4.fillna(50).show()
+ >>> df4.na.fill(50).show()
+---+------+-----+
|age|height| name|
+---+------+-----+
@@ -989,16 +987,6 @@ def fillna(self, value, subset=None):
| 50| 50| null|
+---+------+-----+
- >>> df4.fillna({'age': 50, 'name': 'unknown'}).show()
- +---+------+-------+
- |age|height| name|
- +---+------+-------+
- | 10| 80| Alice|
- | 5| null| Bob|
- | 50| null| Tom|
- | 50| null|unknown|
- +---+------+-------+
-
>>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
+---+------+-------+
|age|height| name|
@@ -1030,6 +1018,8 @@ def fillna(self, value, subset=None):
@since(1.4)
def replace(self, to_replace, value, subset=None):
"""Returns a new :class:`DataFrame` replacing a value with another value.
+ :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
+ aliases of each other.
:param to_replace: int, long, float, string, or list.
Value to be replaced.
@@ -1045,7 +1035,7 @@ def replace(self, to_replace, value, subset=None):
For example, if `value` is a string, and subset contains a non-string column,
then the non-string column is simply ignored.
- >>> df4.replace(10, 20).show()
+ >>> df4.na.replace(10, 20).show()
+----+------+-----+
| age|height| name|
+----+------+-----+
@@ -1055,7 +1045,7 @@ def replace(self, to_replace, value, subset=None):
|null| null| null|
+----+------+-----+
- >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+ >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+----+------+----+
| age|height|name|
+----+------+----+
@@ -1106,9 +1096,9 @@ def replace(self, to_replace, value, subset=None):
@since(1.4)
def corr(self, col1, col2, method=None):
"""
- Calculates the correlation of two columns of a DataFrame as a double value. Currently only
- supports the Pearson Correlation Coefficient.
- :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases.
+ Calculates the correlation of two columns of a DataFrame as a double value.
+ Currently only supports the Pearson Correlation Coefficient.
+ :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other.
:param col1: The name of the first column
:param col2: The name of the second column
@@ -1170,6 +1160,9 @@ def freqItems(self, cols, support=None):
"http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou".
:func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases.
+ .. note:: This function is meant for exploratory data analysis, as we make no \
+ guarantee about the backward compatibility of the schema of the resulting DataFrame.
+
:param cols: Names of the columns to calculate frequent items for as a list or tuple of
strings.
:param support: The frequency with which to consider an item 'frequent'. Default is 1%.
@@ -1214,15 +1207,30 @@ def withColumnRenamed(self, existing, new):
@since(1.4)
@ignore_unicode_prefix
- def drop(self, colName):
+ def drop(self, col):
"""Returns a new :class:`DataFrame` that drops the specified column.
- :param colName: string, name of the column to drop.
+ :param col: a string name of the column to drop, or a
+ :class:`Column` to drop.
>>> df.drop('age').collect()
[Row(name=u'Alice'), Row(name=u'Bob')]
+
+ >>> df.drop(df.age).collect()
+ [Row(name=u'Alice'), Row(name=u'Bob')]
+
+ >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect()
+ [Row(age=5, height=85, name=u'Bob')]
+
+ >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect()
+ [Row(age=5, name=u'Bob', height=85)]
"""
- jdf = self._jdf.drop(colName)
+ if isinstance(col, basestring):
+ jdf = self._jdf.drop(col)
+ elif isinstance(col, Column):
+ jdf = self._jdf.drop(col._jc)
+ else:
+ raise TypeError("col should be a string or a Column")
return DataFrame(jdf, self.sql_ctx)
@since(1.3)
@@ -1239,7 +1247,10 @@ def toPandas(self):
import pandas as pd
return pd.DataFrame.from_records(self.collect(), columns=self.columns)
+ ##########################################################################################
# Pandas compatibility
+ ##########################################################################################
+
groupby = groupBy
drop_duplicates = dropDuplicates
@@ -1259,6 +1270,8 @@ def _to_scala_map(sc, jm):
class DataFrameNaFunctions(object):
"""Functionality for working with missing data in :class:`DataFrame`.
+
+ .. versionadded:: 1.4
"""
def __init__(self, df):
@@ -1274,9 +1287,16 @@ def fill(self, value, subset=None):
fill.__doc__ = DataFrame.fillna.__doc__
+ def replace(self, to_replace, value, subset=None):
+ return self.df.replace(to_replace, value, subset)
+
+ replace.__doc__ = DataFrame.replace.__doc__
+
class DataFrameStatFunctions(object):
"""Functionality for statistic functions with :class:`DataFrame`.
+
+ .. versionadded:: 1.4
"""
def __init__(self, df):
@@ -1316,6 +1336,8 @@ def _test():
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ globs['df3'] = sc.parallelize([Row(name='Alice', age=2),
+ Row(name='Bob', age=5)]).toDF()
globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
Row(name='Bob', age=5, height=None),
Row(name='Tom', age=None, height=None),
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b6fd413bec7d..f036644acc96 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -43,6 +43,44 @@ def _df(self, jdf):
from pyspark.sql.dataframe import DataFrame
return DataFrame(jdf, self._sqlContext)
+ @since(1.4)
+ def format(self, source):
+ """Specifies the input data source format.
+
+ :param source: string, name of the data source, e.g. 'json', 'parquet'.
+
+ >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json')
+ >>> df.dtypes
+ [('age', 'bigint'), ('name', 'string')]
+
+ """
+ self._jreader = self._jreader.format(source)
+ return self
+
+ @since(1.4)
+ def schema(self, schema):
+ """Specifies the input schema.
+
+ Some data sources (e.g. JSON) can infer the input schema automatically from data.
+ By specifying the schema here, the underlying data source can skip the schema
+ inference step, and thus speed up data loading.
+
+ :param schema: a StructType object
+ """
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+ self._jreader = self._jreader.schema(jschema)
+ return self
+
+ @since(1.4)
+ def options(self, **options):
+ """Adds input options for the underlying data source.
+ """
+ for k in options:
+ self._jreader = self._jreader.option(k, options[k])
+ return self
+
@since(1.4)
def load(self, path=None, format=None, schema=None, **options):
"""Loads data from a data source and returns it as a :class`DataFrame`.
@@ -51,21 +89,20 @@ def load(self, path=None, format=None, schema=None, **options):
:param format: optional string for format of the data source. Default to 'parquet'.
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
+
+ >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned')
+ >>> df.dtypes
+ [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
- jreader = self._jreader
if format is not None:
- jreader = jreader.format(format)
+ self.format(format)
if schema is not None:
- if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
- jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
- jreader = jreader.schema(jschema)
- for k in options:
- jreader = jreader.option(k, options[k])
+ self.schema(schema)
+ self.options(**options)
if path is not None:
- return self._df(jreader.load(path))
+ return self._df(self._jreader.load(path))
else:
- return self._df(jreader.load())
+ return self._df(self._jreader.load())
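An illustrative sketch (not part of the patch) of the builder-style reader: `format()`, `schema()`, and `options()` each return the reader, so they chain into `load()`. The JSON file is the one added under python/test_support/sql/ in this patch, and the snippet assumes it runs from SPARK_HOME:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    from pyspark.sql.types import StructType, StructField, StringType, LongType

    sc = SparkContext("local[2]", "reader-demo")   # hypothetical app name
    sqlContext = SQLContext(sc)

    # Supplying a schema skips inference; field order follows the schema.
    schema = StructType([StructField("name", StringType()),
                         StructField("age", LongType())])
    df = (sqlContext.read
          .format("json")
          .schema(schema)
          .load("python/test_support/sql/people.json"))
    print(df.dtypes)   # [('name', 'string'), ('age', 'bigint')]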
@since(1.4)
def json(self, path, schema=None):
@@ -79,47 +116,25 @@ def json(self, path, schema=None):
:param path: string, path to the JSON dataset.
:param schema: an optional :class:`StructType` for the input schema.
- >>> import tempfile, shutil
- >>> jsonFile = tempfile.mkdtemp()
- >>> shutil.rmtree(jsonFile)
- >>> with open(jsonFile, 'w') as f:
- ... f.writelines(jsonStrings)
- >>> df1 = sqlContext.read.json(jsonFile)
- >>> df1.printSchema()
- root
- |-- field1: long (nullable = true)
- |-- field2: string (nullable = true)
- |-- field3: struct (nullable = true)
- | |-- field4: long (nullable = true)
-
- >>> from pyspark.sql.types import *
- >>> schema = StructType([
- ... StructField("field2", StringType()),
- ... StructField("field3",
- ... StructType([StructField("field5", ArrayType(IntegerType()))]))])
- >>> df2 = sqlContext.read.json(jsonFile, schema)
- >>> df2.printSchema()
- root
- |-- field2: string (nullable = true)
- |-- field3: struct (nullable = true)
- | |-- field5: array (nullable = true)
- | | |-- element: integer (containsNull = true)
+ >>> df = sqlContext.read.json('python/test_support/sql/people.json')
+ >>> df.dtypes
+ [('age', 'bigint'), ('name', 'string')]
+
"""
- if schema is None:
- jdf = self._jreader.json(path)
- else:
- jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
- jdf = self._jreader.schema(jschema).json(path)
- return self._df(jdf)
+ if schema is not None:
+ self.schema(schema)
+ return self._df(self._jreader.json(path))
@since(1.4)
def table(self, tableName):
"""Returns the specified table as a :class:`DataFrame`.
- >>> sqlContext.registerDataFrameAsTable(df, "table1")
- >>> df2 = sqlContext.read.table("table1")
- >>> sorted(df.collect()) == sorted(df2.collect())
- True
+ :param tableName: string, name of the table.
+
+ >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+ >>> df.registerTempTable('tmpTable')
+ >>> sqlContext.read.table('tmpTable').dtypes
+ [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
return self._df(self._jreader.table(tableName))
@@ -127,13 +142,9 @@ def table(self, tableName):
def parquet(self, *path):
"""Loads a Parquet file, returning the result as a :class:`DataFrame`.
- >>> import tempfile, shutil
- >>> parquetFile = tempfile.mkdtemp()
- >>> shutil.rmtree(parquetFile)
- >>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlContext.read.parquet(parquetFile)
- >>> sorted(df.collect()) == sorted(df2.collect())
- True
+ >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+ >>> df.dtypes
+ [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path)))
@@ -195,40 +206,88 @@ def __init__(self, df):
self._jwrite = df._jdf.write()
@since(1.4)
- def save(self, path=None, format=None, mode="error", **options):
- """
- Saves the contents of the :class:`DataFrame` to a data source.
+ def mode(self, saveMode):
+ """Specifies the behavior when data or table already exists.
- The data source is specified by the ``format`` and a set of ``options``.
- If ``format`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- Additionally, mode is used to specify the behavior of the save operation when
- data already exists in the data source. There are four modes:
+ Options include:
* `append`: Append contents of this :class:`DataFrame` to existing data.
* `overwrite`: Overwrite existing data.
* `error`: Throw an exception if data already exists.
* `ignore`: Silently ignore this operation if data already exists.
+ >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
+ """
+ self._jwrite = self._jwrite.mode(saveMode)
+ return self
+
+ @since(1.4)
+ def format(self, source):
+ """Specifies the underlying output data source.
+
+ :param source: string, name of the data source, e.g. 'json', 'parquet'.
+
+ >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data'))
+ """
+ self._jwrite = self._jwrite.format(source)
+ return self
+
+ @since(1.4)
+ def options(self, **options):
+ """Adds output options for the underlying data source.
+ """
+ for k in options:
+ self._jwrite = self._jwrite.option(k, options[k])
+ return self
+
+ @since(1.4)
+ def partitionBy(self, *cols):
+ """Partitions the output by the given columns on the file system.
+
+ If specified, the output is laid out on the file system similar
+ to Hive's partitioning scheme.
+
+ :param cols: name of columns
+
+ >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
+ """
+ if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
+ cols = cols[0]
+ self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+ return self
+
+ @since(1.4)
+ def save(self, path=None, format=None, mode="error", **options):
+ """Saves the contents of the :class:`DataFrame` to a data source.
+
+ The data source is specified by the ``format`` and a set of ``options``.
+ If ``format`` is not specified, the default data source configured by
+ ``spark.sql.sources.default`` will be used.
+
:param path: the path in a Hadoop supported file system
:param format: the format used to save
- :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+ :param mode: specifies the behavior of the save operation when data already exists.
+
+ * ``append``: Append contents of this :class:`DataFrame` to existing data.
+ * ``overwrite``: Overwrite existing data.
+ * ``ignore``: Silently ignore this operation if data already exists.
+ * ``error`` (default case): Throw an exception if data already exists.
:param options: all other string options
+
+ >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- jwrite = self._jwrite.mode(mode)
+ self.mode(mode).options(**options)
if format is not None:
- jwrite = jwrite.format(format)
- for k in options:
- jwrite = jwrite.option(k, options[k])
+ self.format(format)
if path is None:
- jwrite.save()
+ self._jwrite.save()
else:
- jwrite.save(path)
+ self._jwrite.save(path)
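An illustrative sketch (not part of the patch) of the writer side: `mode()`, `format()`, `options()`, and `partitionBy()` chain into `save()`; the data and output path are placeholders:

    import os
    import tempfile
    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local[2]", "writer-chain-demo")   # hypothetical app name
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([Row(name=u'Bob', year=2015, month=9, day=1)])

    out = os.path.join(tempfile.mkdtemp(), 'data')
    (df.write
       .format('parquet')
       .mode('overwrite')
       .partitionBy('year', 'month')
       .save(out))
    # Partition columns become directories, e.g. .../year=2015/month=9/part-*.parquet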
+ @since(1.4)
def insertInto(self, tableName, overwrite=False):
- """
- Inserts the content of the :class:`DataFrame` to the specified table.
+ """Inserts the content of the :class:`DataFrame` to the specified table.
+
It requires that the schema of the :class:`DataFrame` is the same as the
schema of the table.
@@ -238,8 +297,7 @@ def insertInto(self, tableName, overwrite=False):
@since(1.4)
def saveAsTable(self, name, format=None, mode="error", **options):
- """
- Saves the content of the :class:`DataFrame` as the specified table.
+ """Saves the content of the :class:`DataFrame` as the specified table.
In the case the table already exists, behavior of this function depends on the
save mode, specified by the `mode` function (default to throwing an exception).
@@ -256,72 +314,61 @@ def saveAsTable(self, name, format=None, mode="error", **options):
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
:param options: all other string options
"""
- jwrite = self._jwrite.mode(mode)
+ self.mode(mode).options(**options)
if format is not None:
- jwrite = jwrite.format(format)
- for k in options:
- jwrite = jwrite.option(k, options[k])
- return jwrite.saveAsTable(name)
+ self.format(format)
+ self._jwrite.saveAsTable(name)
@since(1.4)
def json(self, path, mode="error"):
- """
- Saves the content of the :class:`DataFrame` in JSON format at the
- specified path.
+ """Saves the content of the :class:`DataFrame` in JSON format at the specified path.
- Additionally, mode is used to specify the behavior of the save operation when
- data already exists in the data source. There are four modes:
+ :param path: the path in any Hadoop supported file system
+ :param mode: specifies the behavior of the save operation when data already exists.
- * `append`: Append contents of this :class:`DataFrame` to existing data.
- * `overwrite`: Overwrite existing data.
- * `error`: Throw an exception if data already exists.
- * `ignore`: Silently ignore this operation if data already exists.
+ * ``append``: Append contents of this :class:`DataFrame` to existing data.
+ * ``overwrite``: Overwrite existing data.
+ * ``ignore``: Silently ignore this operation if data already exists.
+ * ``error`` (default case): Throw an exception if data already exists.
- :param path: the path in any Hadoop supported file system
- :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+ >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- return self._jwrite.mode(mode).json(path)
+ self._jwrite.mode(mode).json(path)
@since(1.4)
def parquet(self, path, mode="error"):
- """
- Saves the content of the :class:`DataFrame` in Parquet format at the
- specified path.
+ """Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
- Additionally, mode is used to specify the behavior of the save operation when
- data already exists in the data source. There are four modes:
+ :param path: the path in any Hadoop supported file system
+ :param mode: specifies the behavior of the save operation when data already exists.
- * `append`: Append contents of this :class:`DataFrame` to existing data.
- * `overwrite`: Overwrite existing data.
- * `error`: Throw an exception if data already exists.
- * `ignore`: Silently ignore this operation if data already exists.
+ * ``append``: Append contents of this :class:`DataFrame` to existing data.
+ * ``overwrite``: Overwrite existing data.
+ * ``ignore``: Silently ignore this operation if data already exists.
+ * ``error`` (default case): Throw an exception if data already exists.
- :param path: the path in any Hadoop supported file system
- :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+ >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- return self._jwrite.mode(mode).parquet(path)
+ self._jwrite.mode(mode).parquet(path)
@since(1.4)
def jdbc(self, url, table, mode="error", properties={}):
- """
- Saves the content of the :class:`DataFrame` to a external database table
- via JDBC.
-
- In the case the table already exists in the external database,
- behavior of this function depends on the save mode, specified by the `mode`
- function (default to throwing an exception). There are four modes:
+ """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
- * `append`: Append contents of this :class:`DataFrame` to existing data.
- * `overwrite`: Overwrite existing data.
- * `error`: Throw an exception if data already exists.
- * `ignore`: Silently ignore this operation if data already exists.
+ .. note:: Don't create too many partitions in parallel on a large cluster;\
+ otherwise Spark might crash your external database systems.
- :param url: a JDBC URL of the form `jdbc:subprotocol:subname`
+ :param url: a JDBC URL of the form ``jdbc:subprotocol:subname``
:param table: Name of the table in the external database.
- :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+ :param mode: specifies the behavior of the save operation when data already exists.
+
+ * ``append``: Append contents of this :class:`DataFrame` to existing data.
+ * ``overwrite``: Overwrite existing data.
+ * ``ignore``: Silently ignore this operation if data already exists.
+ * ``error`` (default case): Throw an exception if data already exists.
:param properties: JDBC database connection arguments, a list of
- arbitrary string tag/value. Normally at least a
- "user" and "password" property should be included.
+ arbitrary string tag/value. Normally at least a
+ "user" and "password" property should be included.
"""
jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
for k in properties:
@@ -331,24 +378,23 @@ def jdbc(self, url, table, mode="error", properties={}):
def _test():
import doctest
+ import os
+ import tempfile
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
import pyspark.sql.readwriter
+
+ os.chdir(os.environ["SPARK_HOME"])
+
globs = pyspark.sql.readwriter.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
+
+ globs['tempfile'] = tempfile
+ globs['os'] = os
globs['sc'] = sc
globs['sqlContext'] = SQLContext(sc)
- globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
- .toDF(StructType([StructField('age', IntegerType()),
- StructField('name', StringType())]))
- jsonStrings = [
- '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
- '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
- '"field6":[{"field7": "row2"}]}',
- '{"field1" : null, "field2": "row3", '
- '"field3":{"field4":33, "field5": []}}'
- ]
- globs['jsonStrings'] = jsonStrings
+ globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
+
(failure_count, test_count) = doctest.testmod(
pyspark.sql.readwriter, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5c53c3a8ed4f..b5fbb7d09882 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -26,6 +26,7 @@
import tempfile
import pickle
import functools
+import time
import datetime
import py4j
@@ -47,6 +48,20 @@
from pyspark.sql.window import Window
+class UTC(datetime.tzinfo):
+ """UTC"""
+ ZERO = datetime.timedelta(0)
+
+ def utcoffset(self, dt):
+ return self.ZERO
+
+ def tzname(self, dt):
+ return "UTC"
+
+ def dst(self, dt):
+ return self.ZERO
+
+
class ExamplePointUDT(UserDefinedType):
"""
User-defined type (UDT) for ExamplePoint.
@@ -100,6 +115,15 @@ def test_data_type_eq(self):
lt2 = pickle.loads(pickle.dumps(LongType()))
self.assertEquals(lt, lt2)
+ # regression test for SPARK-7978
+ def test_decimal_type(self):
+ t1 = DecimalType()
+ t2 = DecimalType(10, 2)
+ self.assertTrue(t2 is not t1)
+ self.assertNotEqual(t1, t2)
+ t3 = DecimalType(8)
+ self.assertNotEqual(t2, t3)
+
class SQLTests(ReusedPySparkTestCase):
@@ -122,6 +146,8 @@ def test_range(self):
self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
+ self.assertEqual(self.sqlCtx.range(-2).count(), 0)
+ self.assertEqual(self.sqlCtx.range(3).count(), 3)
def test_explode(self):
from pyspark.sql.functions import explode
@@ -577,6 +603,23 @@ def test_filter_with_datetime(self):
self.assertEqual(0, df.filter(df.date > date).count())
self.assertEqual(0, df.filter(df.time > time).count())
+ def test_time_with_timezone(self):
+ day = datetime.date.today()
+ now = datetime.datetime.now()
+ ts = time.mktime(now.timetuple()) + now.microsecond / 1e6
+ # class in __main__ is not serializable
+ from pyspark.sql.tests import UTC
+ utc = UTC()
+ utcnow = datetime.datetime.fromtimestamp(ts, utc)
+ df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
+ day1, now1, utcnow1 = df.first()
+ # Pyrolite serialize java.sql.Date as datetime, will be fixed in new version
+ self.assertEqual(day1.date(), day)
+ # Pyrolite does not support microsecond, the error should be
+ # less than 1 millisecond
+ self.assertTrue(now - now1 < datetime.timedelta(milliseconds=1))
+ self.assertTrue(now - utcnow1 < datetime.timedelta(milliseconds=1))
+
def test_dropna(self):
schema = StructType([
StructField("name", StringType(), True),
@@ -744,8 +787,10 @@ def setUpClass(cls):
try:
cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
except py4j.protocol.Py4JError:
+ cls.tearDownClass()
raise unittest.SkipTest("Hive is not available")
except TypeError:
+ cls.tearDownClass()
raise unittest.SkipTest("Hive is not available")
os.unlink(cls.tempdir.name)
_scala_HiveContext =\
diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/types.py
similarity index 97%
rename from python/pyspark/sql/_types.py
rename to python/pyspark/sql/types.py
index 9e7e9f04bc35..23d9adb0daea 100644
--- a/python/pyspark/sql/_types.py
+++ b/python/pyspark/sql/types.py
@@ -19,6 +19,7 @@
import decimal
import time
import datetime
+import calendar
import keyword
import warnings
import json
@@ -97,8 +98,6 @@ class AtomicType(DataType):
"""An internal type used to represent everything that is not
null, UDTs, arrays, structs, and maps."""
- __metaclass__ = DataTypeSingleton
-
class NumericType(AtomicType):
"""Numeric data types.
@@ -109,6 +108,8 @@ class IntegralType(NumericType):
"""Integral data types.
"""
+ __metaclass__ = DataTypeSingleton
+
class FractionalType(NumericType):
"""Fractional data types.
@@ -119,26 +120,36 @@ class StringType(AtomicType):
"""String data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class BinaryType(AtomicType):
"""Binary (byte array) data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class BooleanType(AtomicType):
"""Boolean data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class DateType(AtomicType):
"""Date (datetime.date) data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class TimestampType(AtomicType):
"""Timestamp (datetime.datetime) data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class DecimalType(FractionalType):
"""Decimal (decimal.Decimal) data type.
@@ -172,11 +183,15 @@ class DoubleType(FractionalType):
"""Double data type, representing double precision floats.
"""
+ __metaclass__ = DataTypeSingleton
+
class FloatType(FractionalType):
"""Float data type, representing single precision floats.
"""
+ __metaclass__ = DataTypeSingleton
+
class ByteType(IntegralType):
"""Byte data type, i.e. a signed integer in a single byte.
@@ -640,10 +655,15 @@ def _need_python_to_sql_conversion(dataType):
_need_python_to_sql_conversion(dataType.valueType)
elif isinstance(dataType, UserDefinedType):
return True
+ elif isinstance(dataType, (DateType, TimestampType)):
+ return True
else:
return False
+EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
+
+
def _python_to_sql_converter(dataType):
"""
Returns a converter that converts a Python object into a SQL datum for the given type.
@@ -681,18 +701,32 @@ def converter(obj):
return tuple(c(d.get(n)) for n, c in zip(names, converters))
else:
return tuple(c(v) for c, v in zip(converters, obj))
- else:
+ elif obj is not None:
raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
return converter
elif isinstance(dataType, ArrayType):
element_converter = _python_to_sql_converter(dataType.elementType)
- return lambda a: [element_converter(v) for v in a]
+ return lambda a: a and [element_converter(v) for v in a]
elif isinstance(dataType, MapType):
key_converter = _python_to_sql_converter(dataType.keyType)
value_converter = _python_to_sql_converter(dataType.valueType)
- return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+ return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+
elif isinstance(dataType, UserDefinedType):
- return lambda obj: dataType.serialize(obj)
+ return lambda obj: obj and dataType.serialize(obj)
+
+ elif isinstance(dataType, DateType):
+ return lambda d: d and d.toordinal() - EPOCH_ORDINAL
+
+ elif isinstance(dataType, TimestampType):
+
+ def to_posix_timestamp(dt):
+ if dt:
+ seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
+ else time.mktime(dt.timetuple()))
+ return int(seconds * 1e7 + dt.microsecond * 10)
+ return to_posix_timestamp
+
else:
raise ValueError("Unexpected type %r" % dataType)
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index 0a0e006bdf83..c74745c726a0 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -32,7 +32,6 @@ def _to_java_cols(cols):
class Window(object):
-
"""
Utility functions for defining window in DataFrames.
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 33ea8c9293d7..57049beea4db 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -41,8 +41,8 @@
class PySparkStreamingTestCase(unittest.TestCase):
- timeout = 4 # seconds
- duration = .2
+ timeout = 10 # seconds
+ duration = .5
@classmethod
def setUpClass(cls):
@@ -379,13 +379,13 @@ def func(dstream):
class WindowFunctionTests(PySparkStreamingTestCase):
- timeout = 5
+ timeout = 15
def test_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
- return dstream.window(.6, .2).count()
+ return dstream.window(1.5, .5).count()
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
@@ -394,7 +394,7 @@ def test_count_by_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
- return dstream.countByWindow(.6, .2)
+ return dstream.countByWindow(1.5, .5)
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
@@ -403,7 +403,7 @@ def test_count_by_window_large(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
- return dstream.countByWindow(1, .2)
+ return dstream.countByWindow(2.5, .5)
expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
self._test_func(input, func, expected)
@@ -412,7 +412,7 @@ def test_count_by_value_and_window(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
- return dstream.countByValueAndWindow(1, .2)
+ return dstream.countByValueAndWindow(2.5, .5)
expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
self._test_func(input, func, expected)
@@ -421,7 +421,7 @@ def test_group_by_key_and_window(self):
input = [[('a', i)] for i in range(5)]
def func(dstream):
- return dstream.groupByKeyAndWindow(.6, .2).mapValues(list)
+ return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list)
expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
[('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
@@ -615,7 +615,6 @@ def test_kafka_stream(self):
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
- self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
"test-streaming-consumer", {topic: 1},
@@ -631,7 +630,6 @@ def test_kafka_direct_stream(self):
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
- self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
self._validateStreamResult(sendData, stream)
@@ -646,7 +644,6 @@ def test_kafka_direct_stream_from_offset(self):
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
- self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets)
self._validateStreamResult(sendData, stream)
@@ -661,7 +658,6 @@ def test_kafka_rdd(self):
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
- self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
self._validateRddResult(sendData, rdd)
@@ -677,7 +673,6 @@ def test_kafka_rdd_with_leaders(self):
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
- self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
self._validateRddResult(sendData, rdd)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f9fb37f7fc13..11b402e6df6c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -458,6 +458,14 @@ def test_id(self):
self.assertEqual(id + 1, id2)
self.assertEqual(id2, rdd2.id())
+ def test_empty_rdd(self):
+ rdd = self.sc.emptyRDD()
+ self.assertTrue(rdd.isEmpty())
+
+ def test_sum(self):
+ self.assertEqual(0, self.sc.emptyRDD().sum())
+ self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
+
def test_save_as_textfile_with_unicode(self):
# Regression test for SPARK-970
x = u"\u00A1Hola, mundo!"
diff --git a/python/run-tests b/python/run-tests
index ffde2fb24b36..4468fdb3f267 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -57,54 +57,57 @@ function run_test() {
function run_core_tests() {
echo "Run core tests ..."
- run_test "pyspark/rdd.py"
- run_test "pyspark/context.py"
- run_test "pyspark/conf.py"
- PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
- PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
- run_test "pyspark/serializers.py"
- run_test "pyspark/profiler.py"
- run_test "pyspark/shuffle.py"
- run_test "pyspark/tests.py"
+ run_test "pyspark.rdd"
+ run_test "pyspark.context"
+ run_test "pyspark.conf"
+ run_test "pyspark.broadcast"
+ run_test "pyspark.accumulators"
+ run_test "pyspark.serializers"
+ run_test "pyspark.profiler"
+ run_test "pyspark.shuffle"
+ run_test "pyspark.tests"
}
function run_sql_tests() {
echo "Run sql tests ..."
- run_test "pyspark/sql/_types.py"
- run_test "pyspark/sql/context.py"
- run_test "pyspark/sql/column.py"
- run_test "pyspark/sql/dataframe.py"
- run_test "pyspark/sql/group.py"
- run_test "pyspark/sql/functions.py"
- run_test "pyspark/sql/tests.py"
+ run_test "pyspark.sql.types"
+ run_test "pyspark.sql.context"
+ run_test "pyspark.sql.column"
+ run_test "pyspark.sql.dataframe"
+ run_test "pyspark.sql.group"
+ run_test "pyspark.sql.functions"
+ run_test "pyspark.sql.readwriter"
+ run_test "pyspark.sql.window"
+ run_test "pyspark.sql.tests"
}
function run_mllib_tests() {
echo "Run mllib tests ..."
- run_test "pyspark/mllib/classification.py"
- run_test "pyspark/mllib/clustering.py"
- run_test "pyspark/mllib/evaluation.py"
- run_test "pyspark/mllib/feature.py"
- run_test "pyspark/mllib/fpm.py"
- run_test "pyspark/mllib/linalg.py"
- run_test "pyspark/mllib/rand.py"
- run_test "pyspark/mllib/recommendation.py"
- run_test "pyspark/mllib/regression.py"
- run_test "pyspark/mllib/stat/_statistics.py"
- run_test "pyspark/mllib/tree.py"
- run_test "pyspark/mllib/util.py"
- run_test "pyspark/mllib/tests.py"
+ run_test "pyspark.mllib.classification"
+ run_test "pyspark.mllib.clustering"
+ run_test "pyspark.mllib.evaluation"
+ run_test "pyspark.mllib.feature"
+ run_test "pyspark.mllib.fpm"
+ run_test "pyspark.mllib.linalg"
+ run_test "pyspark.mllib.random"
+ run_test "pyspark.mllib.recommendation"
+ run_test "pyspark.mllib.regression"
+ run_test "pyspark.mllib.stat._statistics"
+ run_test "pyspark.mllib.stat.KernelDensity"
+ run_test "pyspark.mllib.tree"
+ run_test "pyspark.mllib.util"
+ run_test "pyspark.mllib.tests"
}
function run_ml_tests() {
echo "Run ml tests ..."
- run_test "pyspark/ml/feature.py"
- run_test "pyspark/ml/classification.py"
- run_test "pyspark/ml/recommendation.py"
- run_test "pyspark/ml/regression.py"
- run_test "pyspark/ml/tuning.py"
- run_test "pyspark/ml/tests.py"
- run_test "pyspark/ml/evaluation.py"
+ run_test "pyspark.ml.feature"
+ run_test "pyspark.ml.classification"
+ run_test "pyspark.ml.recommendation"
+ run_test "pyspark.ml.regression"
+ run_test "pyspark.ml.tuning"
+ run_test "pyspark.ml.tests"
+ run_test "pyspark.ml.evaluation"
}
function run_streaming_tests() {
@@ -124,8 +127,8 @@ function run_streaming_tests() {
done
export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell"
- run_test "pyspark/streaming/util.py"
- run_test "pyspark/streaming/tests.py"
+ run_test "pyspark.streaming.util"
+ run_test "pyspark.streaming.tests"
}
echo "Running PySpark tests. Output is in python/$LOG_FILE."
diff --git a/python/test_support/sql/parquet_partitioned/_SUCCESS b/python/test_support/sql/parquet_partitioned/_SUCCESS
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/python/test_support/sql/parquet_partitioned/_common_metadata b/python/test_support/sql/parquet_partitioned/_common_metadata
new file mode 100644
index 000000000000..7ef2320651de
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_common_metadata differ
diff --git a/python/test_support/sql/parquet_partitioned/_metadata b/python/test_support/sql/parquet_partitioned/_metadata
new file mode 100644
index 000000000000..78a1ca7d3827
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_metadata differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc
new file mode 100644
index 000000000000..e93f42ed6f35
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet
new file mode 100644
index 000000000000..461c382937ec
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc
new file mode 100644
index 000000000000..b63c4d6d1e1d
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc
new file mode 100644
index 000000000000..5bc0ebd71356
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet
new file mode 100644
index 000000000000..62a63915beac
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet
new file mode 100644
index 000000000000..67665a7b55da
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc
new file mode 100644
index 000000000000..ae94a15d08c8
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet
new file mode 100644
index 000000000000..6cb8538aa890
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc
new file mode 100644
index 000000000000..58d9bb5fc588
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet
new file mode 100644
index 000000000000..9b00805481e7
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet differ
diff --git a/python/test_support/sql/people.json b/python/test_support/sql/people.json
new file mode 100644
index 000000000000..50a859cbd7ee
--- /dev/null
+++ b/python/test_support/sql/people.json
@@ -0,0 +1,3 @@
+{"name":"Michael"}
+{"name":"Andy", "age":30}
+{"name":"Justin", "age":19}
diff --git a/repl/pom.xml b/repl/pom.xml
index 03053b4c3b28..85f7bc8ac102 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
- <version>1.4.0-SNAPSHOT</version>
+ <version>1.5.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
@@ -48,6 +48,13 @@
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <classifier>test-jar</classifier>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-bagel_${scala.binary.version}</artifactId>
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 934daaeaafca..50fd43a418bc 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -22,13 +22,12 @@ import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.util.Utils
-class ReplSuite extends FunSuite {
+class ReplSuite extends SparkFunSuite {
def runInterpreter(master: String, input: String): String = {
val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 14f5e9ed4f25..9ecc7c229e38 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -24,14 +24,13 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.tools.nsc.interpreter.SparkILoop
-import org.scalatest.FunSuite
import org.apache.commons.lang3.StringEscapeUtils
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.util.Utils
-class ReplSuite extends FunSuite {
+class ReplSuite extends SparkFunSuite {
def runInterpreter(master: String, input: String): String = {
val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
index c709cde74074..a58eda12b112 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -25,7 +25,6 @@ import scala.language.implicitConversions
import scala.language.postfixOps
import org.scalatest.BeforeAndAfterAll
-import org.scalatest.FunSuite
import org.scalatest.concurrent.Interruptor
import org.scalatest.concurrent.Timeouts._
import org.scalatest.mock.MockitoSugar
@@ -35,7 +34,7 @@ import org.apache.spark._
import org.apache.spark.util.Utils
class ExecutorClassLoaderSuite
- extends FunSuite
+ extends SparkFunSuite
with BeforeAndAfterAll
with MockitoSugar
with Logging {
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 68d980b610c0..d6f927b6fa80 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -14,25 +14,41 @@
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
[The body of this scalastyle-config.xml hunk lost its XML markup in extraction and is summarized here rather than reproduced. Recoverable content: the "Scalastyle standard configuration" header is retained; the no-space-before token list grows from "ARROW, EQUALS" to "ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW", and the space-after token list from "ARROW, EQUALS, COMMA, COLON, IF, WHILE, FOR" to "ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW"; new regex rules ban test classes matching ^FunSuite[A-Za-z]*$ ("Tests must extend org.apache.spark.SparkFunSuite instead.") and calls matching ^println$; and new checks cap file length at 800 lines, methods per type at 30, parameters per method at 10, and method length at 50 lines, with a magic-number check that whitelists -1,0,1,2,3.]
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 5c322d032d47..f4b1cc3a4ffe 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
- <version>1.4.0-SNAPSHOT</version>
+ <version>1.5.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
@@ -36,10 +36,6 @@
- <dependency>
- <groupId>org.scala-lang</groupId>
- <artifactId>scala-compiler</artifactId>
- </dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
@@ -50,6 +46,13 @@
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <classifier>test-jar</classifier>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-unsafe_${scala.binary.version}</artifactId>
@@ -60,6 +63,11 @@
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.codehaus.janino</groupId>
+ <artifactId>janino</artifactId>
+ <version>2.7.8</version>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
@@ -101,13 +109,6 @@
<name>!scala-2.11</name>
</property>
</activation>
- <dependencies>
- <dependency>
- <groupId>org.scalamacros</groupId>
- <artifactId>quasiquotes_${scala.binary.version}</artifactId>
- <version>${scala.macros.version}</version>
- </dependency>
- </dependencies>
</profile>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java
new file mode 100644
index 000000000000..acec2bf4520f
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql;
+
+import org.apache.spark.sql.catalyst.expressions.MutableRow;
+
+public abstract class BaseMutableRow extends BaseRow implements MutableRow {
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setInt(int ordinal, int value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setLong(int ordinal, long value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setDouble(int ordinal, double value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setBoolean(int ordinal, boolean value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setShort(int ordinal, short value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setByte(int ordinal, byte value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setFloat(int ordinal, float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setString(int ordinal, String value) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
new file mode 100644
index 000000000000..611e02d8fb66
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql;
+
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.sql.Timestamp;
+import java.util.List;
+
+import scala.collection.Seq;
+import scala.collection.mutable.ArraySeq;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.types.StructType;
+
+public abstract class BaseRow extends InternalRow {
+
+ @Override
+ final public int length() {
+ return size();
+ }
+
+ @Override
+ public boolean anyNull() {
+ final int n = size();
+ for (int i=0; i < n; i++) {
+ if (isNullAt(i)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public StructType schema() { throw new UnsupportedOperationException(); }
+
+ @Override
+ final public Object apply(int i) {
+ return get(i);
+ }
+
+ @Override
+ public int getInt(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long getLong(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float getFloat(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double getDouble(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public byte getByte(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public short getShort(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean getBoolean(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String getString(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public BigDecimal getDecimal(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Date getDate(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Timestamp getTimestamp(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> Seq<T> getSeq(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> List<T> getList(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <K, V> scala.collection.Map<K, V> getMap(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <K, V> java.util.Map<K, V> getJavaMap(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Row getStruct(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T getAs(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T getAs(String fieldName) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int fieldIndex(String name) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
+ * A generic version of Row.equals(Row), which is used for tests.
+ */
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof Row) {
+ Row row = (Row) other;
+ int n = size();
+ if (n != row.size()) {
+ return false;
+ }
+ for (int i = 0; i < n; i ++) {
+ if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public InternalRow copy() {
+ final int n = size();
+ Object[] arr = new Object[n];
+ for (int i = 0; i < n; i++) {
+ arr[i] = get(i);
+ }
+ return new GenericRow(arr);
+ }
+
+ @Override
+ public Seq
+
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
@@ -54,11 +61,11 @@
       <scope>test</scope>
     </dependency>
     <dependency>
-      <groupId>com.twitter</groupId>
+      <groupId>org.apache.parquet</groupId>
       <artifactId>parquet-column</artifactId>
     </dependency>
     <dependency>
-      <groupId>com.twitter</groupId>
+      <groupId>org.apache.parquet</groupId>
       <artifactId>parquet-hadoop</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b49b1d327289..d3efa83380d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -716,6 +716,18 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*/
def endsWith(literal: String): Column = this.endsWith(lit(literal))
+ /**
+ * Gives the column an alias. Same as `as`.
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select($"colA".alias("colB"))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 1.4.0
+ */
+ def alias(alias: String): Column = as(alias)
+
/**
* Gives the column an alias.
* {{{
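
A quick illustration of the new method above, with a hypothetical DataFrame df that has a column colA (the $"..." syntax assumes import sqlContext.implicits._ is in scope); alias simply forwards to as, so both forms are equivalent:

import sqlContext.implicits._   // brings the $"..." column syntax into scope

// Both rename colA to colB in the select output.
df.select($"colA".alias("colB"))
df.select($"colA".as("colB"))
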
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index e90109446b64..444916bbadb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
import java.util.Properties
-import scala.collection.JavaConversions._
import scala.language.implicitConversions
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
@@ -33,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -57,14 +56,11 @@ private[sql] object DataFrame {
* :: Experimental ::
* A distributed collection of data organized into named columns.
*
- * A [[DataFrame]] is equivalent to a relational table in Spark SQL. There are multiple ways
- * to create a [[DataFrame]]:
+ * A [[DataFrame]] is equivalent to a relational table in Spark SQL. The following example creates
+ * a [[DataFrame]] by pointing Spark SQL to a Parquet data set.
* {{{
- * // Create a DataFrame from Parquet files
- * val people = sqlContext.parquetFile("...")
- *
- * // Create a DataFrame from data sources
- * val df = sqlContext.load("...", "json")
+ * val people = sqlContext.read.parquet("...") // in Scala
+ * DataFrame people = sqlContext.read().parquet("...") // in Java
* }}}
*
* Once created, it can be manipulated using the various domain-specific-language (DSL) functions
@@ -86,8 +82,8 @@ private[sql] object DataFrame {
* A more concrete example in Scala:
* {{{
* // To create DataFrame using SQLContext
- * val people = sqlContext.parquetFile("...")
- * val department = sqlContext.parquetFile("...")
+ * val people = sqlContext.read.parquet("...")
+ * val department = sqlContext.read.parquet("...")
*
* people.filter("age > 30")
* .join(department, people("deptId") === department("id"))
@@ -98,8 +94,8 @@ private[sql] object DataFrame {
* and in Java:
* {{{
* // To create DataFrame using SQLContext
- * DataFrame people = sqlContext.parquetFile("...");
- * DataFrame department = sqlContext.parquetFile("...");
+ * DataFrame people = sqlContext.read().parquet("...");
+ * DataFrame department = sqlContext.read().parquet("...");
*
* people.filter("age".gt(30))
* .join(department, people.col("deptId").equalTo(department("id")))
@@ -172,23 +168,34 @@ class DataFrame private[sql](
/**
* Internal API for Python
- * @param numRows Number of rows to show
+ * @param _numRows Number of rows to show
*/
- private[sql] def showString(numRows: Int): String = {
+ private[sql] def showString(_numRows: Int): String = {
+ val numRows = _numRows.max(0)
val sb = new StringBuilder
- val data = take(numRows)
+ val takeResult = take(numRows + 1)
+ val hasMoreData = takeResult.length > numRows
+ val data = takeResult.take(numRows)
val numCols = schema.fieldNames.length
+ // For array values, replace Seq and Array with square brackets
// For cells that are beyond 20 characters, replace it with the first 17 and "..."
val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
row.toSeq.map { cell =>
- val str = if (cell == null) "null" else cell.toString
+ val str = cell match {
+ case null => "null"
+ case array: Array[_] => array.mkString("[", ", ", "]")
+ case seq: Seq[_] => seq.mkString("[", ", ", "]")
+ case _ => cell.toString
+ }
if (str.length > 20) str.substring(0, 17) + "..." else str
}: Seq[String]
}
+ // Initialise the width of each column to a minimum value of '3'
+ val colWidths = Array.fill(numCols)(3)
+
// Compute the width of each column
- val colWidths = Array.fill(numCols)(0)
for (row <- rows) {
for ((cell, i) <- row.zipWithIndex) {
colWidths(i) = math.max(colWidths(i), cell.length)
@@ -200,7 +207,7 @@ class DataFrame private[sql](
// column names
rows.head.zipWithIndex.map { case (cell, i) =>
- StringUtils.leftPad(cell.toString, colWidths(i))
+ StringUtils.leftPad(cell, colWidths(i))
}.addString(sb, "|", "|", "|\n")
sb.append(sep)
@@ -213,6 +220,13 @@ class DataFrame private[sql](
}
sb.append(sep)
+
+ // For data that has more than "numRows" records
+ if (hasMoreData) {
+ val rowsString = if (numRows == 1) "row" else "rows"
+ sb.append(s"only showing top $numRows ${rowsString}\n")
+ }
+
sb.toString()
}
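
A sketch of what the new truncation logic produces, assuming a hypothetical df holding three rows: take(numRows + 1) fetches one extra row to detect that more data exists, so show(2) prints only two rows plus the new footer. The output looks roughly like:

df.show(2)
// +---+-----+
// |age| name|
// +---+-----+
// |  1|Alice|
// |  2|  Bob|
// +---+-----+
// only showing top 2 rows
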
@@ -398,22 +412,50 @@ class DataFrame private[sql](
* @since 1.4.0
*/
def join(right: DataFrame, usingColumn: String): DataFrame = {
+ join(right, Seq(usingColumn))
+ }
+
+ /**
+ * Inner equi-join with another [[DataFrame]] using the given columns.
+ *
+ * Different from other join functions, the join columns will only appear once in the output,
+ * i.e. similar to SQL's `JOIN USING` syntax.
+ *
+ * {{{
+ * // Joining df1 and df2 using the columns "user_id" and "user_name"
+ * df1.join(df2, Seq("user_id", "user_name"))
+ * }}}
+ *
+ * Note that if you perform a self-join using this function without aliasing the input
+ * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since
+ * there is no way to disambiguate which side of the join you would like to reference.
+ *
+ * @param right Right side of the join operation.
+ * @param usingColumns Names of the columns to join on. These columns must exist on both sides.
+ * @group dfops
+ * @since 1.4.0
+ */
+ def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = {
// Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
// by creating a new instance for one of the branch.
val joined = sqlContext.executePlan(
Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join]
- // Project only one of the join column.
- val joinedCol = joined.right.resolve(usingColumn)
+ // Project only one of the join columns.
+ val joinedCols = usingColumns.map(col => joined.right.resolve(col))
+ val condition = usingColumns.map { col =>
+ catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col))
+ }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) =>
+ catalyst.expressions.And(cond, eqTo)
+ }
+
Project(
- joined.output.filterNot(_ == joinedCol),
+ joined.output.filterNot(joinedCols.contains(_)),
Join(
joined.left,
joined.right,
joinType = Inner,
- Some(catalyst.expressions.EqualTo(
- joined.left.resolve(usingColumn),
- joined.right.resolve(usingColumn))))
+ condition)
)
}
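
A sketch of the new multi-column USING join, assuming hypothetical df1 and df2 that both carry user_id and user_name columns:

val joined = df1.join(df2, Seq("user_id", "user_name"))
// The join keys appear only once in the output, followed by the remaining
// columns of the left side and then those of the right side.
joined.printSchema()
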
@@ -577,7 +619,7 @@ class DataFrame private[sql](
def as(alias: Symbol): DataFrame = as(alias.name)
/**
- * Selects a set of expressions.
+ * Selects a set of column-based expressions.
* {{{
* df.select($"colA", $"colB" + 1)
* }}}
@@ -989,7 +1031,8 @@ class DataFrame private[sql](
val names = schema.toAttributes.map(_.name)
val rowFunction =
- f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
+ f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema)
+ .asInstanceOf[InternalRow]))
val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))
Generate(generator, join = true, outer = false,
@@ -1015,8 +1058,9 @@ class DataFrame private[sql](
val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) }
val names = attributes.map(_.name)
- def rowFunction(row: Row): TraversableOnce[Row] = {
- f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType)))
+ def rowFunction(row: Row): TraversableOnce[InternalRow] = {
+ f(row(0).asInstanceOf[A]).map(o =>
+ InternalRow(CatalystTypeConverters.convertToCatalyst(o, dataType)))
}
val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)
@@ -1085,6 +1129,22 @@ class DataFrame private[sql](
}
}
+ /**
+ * Returns a new [[DataFrame]] with a column dropped.
+ * This version of drop accepts a Column rather than a name.
+ * This is a no-op if the DataFrame doesn't have a column
+ * with an equivalent expression.
+ * @group dfops
+ * @since 1.4.1
+ */
+ def drop(col: Column): DataFrame = {
+ val attrs = this.logicalPlan.output
+ val colsAfterDrop = attrs.filter { attr =>
+ attr != col.expr
+ }.map(attr => Column(attr))
+ select(colsAfterDrop : _*)
+ }
+
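
A sketch of the new Column-based drop, assuming a hypothetical df with columns id and tmp; a Column expression that does not match any output attribute simply leaves the DataFrame unchanged:

val pruned = df.drop(df("tmp"))
pruned.columns   // Array("id")
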
/**
* Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]].
* This is an alias for `distinct`.
@@ -1162,7 +1222,7 @@ class DataFrame private[sql](
val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
- val ret: Seq[Row] = if (outputCols.nonEmpty) {
+ val ret: Seq[InternalRow] = if (outputCols.nonEmpty) {
val aggExprs = statistics.flatMap { case (_, colToAgg) =>
outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
}
@@ -1171,11 +1231,12 @@ class DataFrame private[sql](
// Pivot the data so each summary is one row
row.grouped(outputCols.size).toSeq.zip(statistics).map {
- case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*)
+ case (aggregation, (statistic, _)) =>
+ InternalRow(statistic :: aggregation.toList: _*)
}
} else {
// If there are no output columns, just output a single column that contains the stats.
- statistics.map { case (name, _) => Row(name) }
+ statistics.map { case (name, _) => InternalRow(name) }
}
// All columns are string type
@@ -1298,7 +1359,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- override def distinct: DataFrame = Distinct(logicalPlan)
+ override def distinct: DataFrame = dropDuplicates()
/**
* @group basic
@@ -1444,7 +1505,9 @@ class DataFrame private[sql](
////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
- /** Left here for backward compatibility. */
+ /**
+ * @deprecated As of 1.3.0, replaced by `toDF()`.
+ */
@deprecated("use toDF", "1.3.0")
def toSchemaRDD: DataFrame = this
@@ -1455,6 +1518,7 @@ class DataFrame private[sql](
* given name; if you pass `false`, it will throw if the table already
* exists.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().jdbc()`.
*/
@deprecated("Use write.jdbc()", "1.4.0")
def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = {
@@ -1473,6 +1537,7 @@ class DataFrame private[sql](
* the RDD in order via the simple statement
* `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().jdbc()`.
*/
@deprecated("Use write.jdbc()", "1.4.0")
def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = {
@@ -1485,6 +1550,7 @@ class DataFrame private[sql](
* Files that are written out using this method can be read back in as a [[DataFrame]]
* using the `parquetFile` function in [[SQLContext]].
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().parquet()`.
*/
@deprecated("Use write.parquet(path)", "1.4.0")
def saveAsParquetFile(path: String): Unit = {
@@ -1508,6 +1574,7 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`.
*/
@deprecated("Use write.saveAsTable(tableName)", "1.4.0")
def saveAsTable(tableName: String): Unit = {
@@ -1526,6 +1593,7 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`.
*/
@deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0")
def saveAsTable(tableName: String, mode: SaveMode): Unit = {
@@ -1545,6 +1613,7 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`.
*/
@deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0")
def saveAsTable(tableName: String, source: String): Unit = {
@@ -1564,6 +1633,7 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).saveAsTable(tableName)`.
*/
@deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0")
def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = {
@@ -1582,6 +1652,8 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`.
*/
@deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)",
"1.4.0")
@@ -1606,6 +1678,8 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`.
*/
@deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)",
"1.4.0")
@@ -1622,6 +1696,7 @@ class DataFrame private[sql](
* using the default data source configured by spark.sql.sources.default and
* [[SaveMode.ErrorIfExists]] as the save mode.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().save(path)`.
*/
@deprecated("Use write.save(path)", "1.4.0")
def save(path: String): Unit = {
@@ -1632,6 +1707,7 @@ class DataFrame private[sql](
* Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode,
* using the default data source configured by spark.sql.sources.default.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`.
*/
@deprecated("Use write.mode(mode).save(path)", "1.4.0")
def save(path: String, mode: SaveMode): Unit = {
@@ -1642,6 +1718,7 @@ class DataFrame private[sql](
* Saves the contents of this DataFrame to the given path based on the given data source,
* using [[SaveMode.ErrorIfExists]] as the save mode.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`.
*/
@deprecated("Use write.format(source).save(path)", "1.4.0")
def save(path: String, source: String): Unit = {
@@ -1652,6 +1729,7 @@ class DataFrame private[sql](
* Saves the contents of this DataFrame to the given path based on the given data source and
* [[SaveMode]] specified by mode.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`.
*/
@deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0")
def save(path: String, source: String, mode: SaveMode): Unit = {
@@ -1662,6 +1740,8 @@ class DataFrame private[sql](
* Saves the contents of this DataFrame based on the given data source,
* [[SaveMode]] specified by mode, and a set of options.
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().format(source).mode(mode).options(options).save(path)`.
*/
@deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0")
def save(
@@ -1676,6 +1756,8 @@ class DataFrame private[sql](
* Saves the contents of this DataFrame based on the given data source,
* [[SaveMode]] specified by mode, and a set of options
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().format(source).mode(mode).options(options).save(path)`.
*/
@deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0")
def save(
@@ -1689,6 +1771,8 @@ class DataFrame private[sql](
/**
* Adds the rows from this RDD to the specified table, optionally overwriting the existing data.
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`.
*/
@deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0")
def insertInto(tableName: String, overwrite: Boolean): Unit = {
@@ -1699,6 +1783,8 @@ class DataFrame private[sql](
* Adds the rows from this RDD to the specified table.
* Throws an exception if the table already exists.
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().mode(SaveMode.Append).saveAsTable(tableName)`.
*/
@deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0")
def insertInto(tableName: String): Unit = {
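
All of the deprecated save/insert entry points above forward to the new DataFrameWriter. Two typical migrations, with a hypothetical df, path, and table name:

import org.apache.spark.sql.SaveMode

df.saveAsParquetFile("/tmp/people.parquet")   // deprecated in 1.4
df.write.parquet("/tmp/people.parquet")       // replacement

df.saveAsTable("people", "parquet", SaveMode.Overwrite)                     // deprecated in 1.4
df.write.format("parquet").mode(SaveMode.Overwrite).saveAsTable("people")   // replacement
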
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index b44d4c86ac5d..1828ed1aab50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -245,7 +245,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) {
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(jsonRDD, 1.0, columnNameOfCorruptJsonRecord)))
val rowRDD = JsonRDD.jsonStringToRow(jsonRDD, appliedSchema, columnNameOfCorruptJsonRecord)
- sqlContext.createDataFrame(rowRDD, appliedSchema, needsConversion = false)
+ sqlContext.internalCreateDataFrame(rowRDD, appliedSchema)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 5d106c1ac267..edb9ed7bba56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -43,7 +43,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
- * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
+ * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
*
* @param col1 the name of the column
@@ -97,6 +97,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* The `support` should be greater than 1e-4.
*
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting [[DataFrame]].
+ *
* @param cols the names of the columns to search frequent items in.
* @param support The minimum frequency for an item to be considered `frequent`. Should be greater
* than 1e-4.
@@ -114,6 +117,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* Uses a `default` support of 1%.
*
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting [[DataFrame]].
+ *
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
*
@@ -128,6 +134,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* frequent element count algorithm described in
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
*
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting [[DataFrame]].
+ *
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
*
@@ -143,6 +152,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* Uses a `default` support of 1%.
*
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting [[DataFrame]].
+ *
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
*
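
The doc additions above all concern the approximate frequent-items helpers. A sketch with a hypothetical df that has columns a and b; as the new note says, the results and output schema are exploratory rather than stable:

// Items that appear in at least roughly 40% of the rows of each column.
val freq = df.stat.freqItems(Seq("a", "b"), 0.4)
freq.show()   // a single row with one array of frequent items per input column
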
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 516ba2ac2337..45b3e1bc627d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -40,22 +40,22 @@ private[sql] object GroupedData {
/**
* The Grouping Type
*/
- trait GroupType
+ private[sql] trait GroupType
/**
* To indicate it's the GroupBy
*/
- object GroupByType extends GroupType
+ private[sql] object GroupByType extends GroupType
/**
* To indicate it's the CUBE
*/
- object CubeType extends GroupType
+ private[sql] object CubeType extends GroupType
/**
* To indicate it's the ROLLUP
*/
- object RollupType extends GroupType
+ private[sql] object RollupType extends GroupType
}
/**
@@ -249,7 +249,7 @@ class GroupedData protected[sql](
def mean(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Average)
}
-
+
/**
* Compute the max value for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 77c6af27d100..55ab6b3358e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -71,8 +71,12 @@ private[spark] object SQLConf {
// Whether to perform partition discovery when loading external data sources. Default to true.
val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled"
+ // Whether to perform partition column type inference. Default to true.
+ val PARTITION_COLUMN_TYPE_INFERENCE = "spark.sql.sources.partitionColumnTypeInference.enabled"
+
// The output committer class used by FSBasedRelation. The specified class needs to be a
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
+ // NOTE: This property should be set in Hadoop `Configuration` rather than Spark `SQLConf`
val OUTPUT_COMMITTER_CLASS = "spark.sql.sources.outputCommitterClass"
// Whether to perform eager analysis when constructing a dataframe.
@@ -157,7 +161,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean
/** When true the planner will use the external sort, which may spill to disk. */
- private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean
+ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "true").toBoolean
/**
* Sort merge join would sort the two side of join first, and then iterate both sides together
@@ -167,15 +171,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean
/**
- * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
+ * When set to true, Spark SQL will use Janino at runtime to generate custom bytecode
* that evaluates expressions found in queries. In general this custom code runs much faster
- * than interpreted evaluation, but there are significant start-up costs due to compilation.
- * As a result codegen is only beneficial when queries run for a long time, or when the same
- * expressions are used multiple times.
- *
- * Defaults to false as this feature is currently experimental.
+ * than interpreted evaluation, but there are some start-up costs (5-10ms) due to compilation.
*/
- private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean
+ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "true").toBoolean
/**
* caseSensitive analysis true by default
@@ -250,6 +250,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def partitionDiscoveryEnabled() =
getConf(SQLConf.PARTITION_DISCOVERY_ENABLED, "true").toBoolean
+ private[spark] def partitionColumnTypeInferenceEnabled() =
+ getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE, "true").toBoolean
+
// Do not use a value larger than 4000 as the default value of this property.
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
private[spark] def schemaStringLengthThreshold: Int =
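
These are plain string-keyed settings, so they can be toggled on a hypothetical sqlContext as shown below; the codegen and external-sort flags whose defaults flipped to true above are flipped back the same way:

// Keep partition discovery, but fall back to string-typed partition columns.
sqlContext.setConf("spark.sql.sources.partitionDiscovery.enabled", "true")
sqlContext.setConf("spark.sql.sources.partitionColumnTypeInference.enabled", "false")
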
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a32897c20b47..9d1f89d6d7bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -31,14 +31,13 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.ParserDialect
+import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution.{Filter, _}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -120,7 +119,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
// TODO how to handle the temp function per user session?
@transient
- protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(conf)
+ protected[sql] lazy val functionRegistry: FunctionRegistry =
+ new OverrideFunctionRegistry(FunctionRegistry.builtin)
@transient
protected[sql] lazy val analyzer: Analyzer =
@@ -182,9 +182,28 @@ class SQLContext(@transient val sparkContext: SparkContext)
conf.dialect
}
- sparkContext.getConf.getAll.foreach {
- case (key, value) if key.startsWith("spark.sql") => setConf(key, value)
- case _ =>
+ {
+ // We extract Spark SQL settings from SparkContext's conf and put them into
+ // Spark SQL's conf.
+ // First, we populate the SQLConf (conf), so that other values that use these settings
+ // in their construction can pick up the correct values.
+ // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version
+ // and spark.sql.hive.metastore.jars to get correctly constructed.
+ val properties = new Properties
+ sparkContext.getConf.getAll.foreach {
+ case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value)
+ case _ =>
+ }
+ // We put these settings directly into conf to avoid calling setConf, which may have
+ // side effects. For example, in HiveContext, setConf may cause executionHive and metadataHive
+ // to be constructed. If we called setConf here, the constructed metadataHive may have
+ // wrong settings, or the construction may fail.
+ conf.setConf(properties)
+ // After we have populated SQLConf, we call setConf to populate other confs in the subclass
+ // (e.g. hiveconf in HiveContext).
+ properties.foreach {
+ case (key, value) => setConf(key, value)
+ }
}
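
A sketch of the behaviour this block implements: any spark.sql.* entry on the SparkConf is copied into the SQLConf before setConf side effects run, so it is visible as soon as the SQLContext exists. Names below are illustrative:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf()
  .setMaster("local[2]")
  .setAppName("sql-conf-propagation")
  .set("spark.sql.shuffle.partitions", "12"))
val sqlContext = new SQLContext(sc)

// The value set on the SparkConf is already present in the SQLConf.
assert(sqlContext.getConf("spark.sql.shuffle.partitions") == "12")
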
@transient
@@ -466,14 +485,26 @@ class SQLContext(@transient val sparkContext: SparkContext)
// schema differs from the existing schema on any field data type.
val catalystRows = if (needsConversion) {
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
- rowRDD.map(converter(_).asInstanceOf[Row])
+ rowRDD.map(converter(_).asInstanceOf[InternalRow])
} else {
- rowRDD
+ rowRDD.map { r: Row => InternalRow.fromSeq(r.toSeq) }
}
val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
DataFrame(this, logicalPlan)
}
+ /**
+ * Creates a DataFrame from an RDD[InternalRow]. The rows are assumed to already be in
+ * Catalyst's internal format and are used as-is, without any conversion.
+ */
+ private[sql]
+ def internalCreateDataFrame(catalystRows: RDD[InternalRow], schema: StructType) = {
+ // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
+ // schema differs from the existing schema on any field data type.
+ val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
+ DataFrame(this, logicalPlan)
+ }
+
/**
* :: DeveloperApi ::
* Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s using the given schema.
@@ -511,7 +542,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
extractors.zip(attributeSeq).map { case (e, attr) =>
CatalystTypeConverters.convertToCatalyst(e.invoke(row), attr.dataType)
}.toArray[Any]
- ) : Row
+ ) : InternalRow
}
}
DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this))
@@ -686,7 +717,18 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* :: Experimental ::
* Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
- * in an range from `start` to `end`(exclusive) with step value 1.
+ * in a range from 0 to `end` (exclusive) with step value 1.
+ *
+ * @since 1.4.1
+ * @group dataframe
+ */
+ @Experimental
+ def range(end: Long): DataFrame = range(0, end)
+
+ /**
+ * :: Experimental ::
+ * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
+ * in a range from `start` to `end` (exclusive) with step value 1.
*
* @since 1.4.0
* @group dataframe
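
The range variants side by side, assuming a hypothetical sqlContext:

sqlContext.range(1000)             // new 1.4.1 single-argument form: ids 0 to 999
sqlContext.range(0, 1000)          // explicit start and end
sqlContext.range(0, 1000, 10, 4)   // step 10, spread over 4 partitions
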
@@ -701,7 +743,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* :: Experimental ::
* Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
- * in an range from `start` to `end`(exclusive) with an step value, with partition number
+ * in a range from `start` to `end` (exclusive) with a step value, with the partition number
* specified.
*
* @since 1.4.0
@@ -855,7 +897,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] val planner = new SparkPlanner
@transient
- protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1)
+ protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1)
/**
* Prepares a planned SparkPlan for execution by inserting shuffle operations as needed.
@@ -886,6 +928,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
tlSession.remove()
}
+ protected[sql] def setSession(session: SQLSession): Unit = {
+ detachSession()
+ tlSession.set(session)
+ }
+
protected[sql] class SQLSession {
// Note that this is a lazy val so we can override the default value in subclasses.
protected[sql] lazy val conf: SQLConf = new SQLConf
@@ -917,7 +964,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan)
/** Internal version of the RDD. Avoids copies and has no schema */
- lazy val toRdd: RDD[Row] = executedPlan.execute()
+ lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
@@ -999,7 +1046,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
val rowRdd = convertedRdd.mapPartitions { iter =>
- iter.map { m => new GenericRow(m): Row}
+ iter.map { m => new GenericRow(m): InternalRow}
}
DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
@@ -1021,21 +1068,33 @@ class SQLContext(@transient val sparkContext: SparkContext)
////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
+ /**
+ * @deprecated As of 1.3.0, replaced by `createDataFrame()`.
+ */
@deprecated("use createDataFrame", "1.3.0")
def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
createDataFrame(rowRDD, schema)
}
+ /**
+ * @deprecated As of 1.3.0, replaced by `createDataFrame()`.
+ */
@deprecated("use createDataFrame", "1.3.0")
def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
createDataFrame(rowRDD, schema)
}
+ /**
+ * @deprecated As of 1.3.0, replaced by `createDataFrame()`.
+ */
@deprecated("use createDataFrame", "1.3.0")
def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
createDataFrame(rdd, beanClass)
}
+ /**
+ * @deprecated As of 1.3.0, replaced by `createDataFrame()`.
+ */
@deprecated("use createDataFrame", "1.3.0")
def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
createDataFrame(rdd, beanClass)
@@ -1046,6 +1105,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* [[DataFrame]] if no paths are passed in.
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().parquet()`.
*/
@deprecated("Use read.parquet()", "1.4.0")
@scala.annotation.varargs
@@ -1065,6 +1125,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* It goes through the entire dataset once to determine the schema.
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonFile(path: String): DataFrame = {
@@ -1076,6 +1137,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* returning the result as a [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonFile(path: String, schema: StructType): DataFrame = {
@@ -1084,6 +1146,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonFile(path: String, samplingRatio: Double): DataFrame = {
@@ -1096,6 +1159,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* It goes through the entire dataset once to determine the schema.
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: RDD[String]): DataFrame = read.json(json)
@@ -1106,6 +1170,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* It goes through the entire dataset once to determine the schema.
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json)
@@ -1115,6 +1180,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* returning the result as a [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: RDD[String], schema: StructType): DataFrame = {
@@ -1126,6 +1192,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* schema, returning the result as a [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = {
@@ -1137,6 +1204,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* schema, returning the result as a [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = {
@@ -1148,6 +1216,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* schema, returning the result as a [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = {
@@ -1159,6 +1228,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* using the default data source configured by spark.sql.sources.default.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by `read().load(path)`.
*/
@deprecated("Use read.load(path)", "1.4.0")
def load(path: String): DataFrame = {
@@ -1169,6 +1239,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Returns the dataset stored at path as a DataFrame, using the given data source.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`.
*/
@deprecated("Use read.format(source).load(path)", "1.4.0")
def load(path: String, source: String): DataFrame = {
@@ -1180,6 +1251,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* a set of options as a DataFrame.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`.
*/
@deprecated("Use read.format(source).options(options).load()", "1.4.0")
def load(source: String, options: java.util.Map[String, String]): DataFrame = {
@@ -1191,6 +1263,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* a set of options as a DataFrame.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`.
*/
@deprecated("Use read.format(source).options(options).load()", "1.4.0")
def load(source: String, options: Map[String, String]): DataFrame = {
@@ -1202,6 +1275,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
* a set of options as a DataFrame, using the given schema as the schema of the DataFrame.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by
+ * `read().format(source).schema(schema).options(options).load()`.
*/
@deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0")
def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame =
@@ -1214,6 +1289,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
* a set of options as a DataFrame, using the given schema as the schema of the DataFrame.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by
+ * `read().format(source).schema(schema).options(options).load()`.
*/
@deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0")
def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = {
@@ -1225,6 +1302,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* url named table.
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().jdbc()`.
*/
@deprecated("use read.jdbc()", "1.4.0")
def jdbc(url: String, table: String): DataFrame = {
@@ -1242,6 +1320,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split
* evenly into this many partitions
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().jdbc()`.
*/
@deprecated("use read.jdbc()", "1.4.0")
def jdbc(
@@ -1261,6 +1340,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* of the [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().jdbc()`.
*/
@deprecated("use read.jdbc()", "1.4.0")
def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = {
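
As on the writer side, the deprecated load/json/jdbc methods above map onto DataFrameReader. A sketch with hypothetical paths and connection settings:

val events  = sqlContext.load("/data/events", "json")               // deprecated in 1.4
val events2 = sqlContext.read.format("json").load("/data/events")   // replacement

val people  = sqlContext.jdbc("jdbc:postgresql:db", "people")                                  // deprecated
val people2 = sqlContext.read.jdbc("jdbc:postgresql:db", "people", new java.util.Properties)  // replacement
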
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 423ecdff5804..43b62f0e822f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -106,7 +106,7 @@ private[r] object SQLUtils {
dfCols.map { col =>
colToRBytes(col)
- }
+ }
}
def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = {
@@ -121,7 +121,7 @@ private[r] object SQLUtils {
val numRows = col.length
val bos = new ByteArrayOutputStream()
val dos = new DataOutputStream(bos)
-
+
SerDe.writeInt(dos, numRows)
col.map { item =>
@@ -139,4 +139,19 @@ private[r] object SQLUtils {
case "ignore" => SaveMode.Ignore
}
}
+
+ def loadDF(
+ sqlContext: SQLContext,
+ source: String,
+ options: java.util.Map[String, String]): DataFrame = {
+ sqlContext.read.format(source).options(options).load()
+ }
+
+ def loadDF(
+ sqlContext: SQLContext,
+ source: String,
+ schema: StructType,
+ options: java.util.Map[String, String]): DataFrame = {
+ sqlContext.read.format(source).schema(schema).options(options).load()
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index aa10af400c81..1949625699ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar
import java.nio.{ByteBuffer, ByteOrder}
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.columnar.ColumnBuilder._
import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder}
import org.apache.spark.sql.types._
@@ -33,7 +33,7 @@ private[sql] trait ColumnBuilder {
/**
* Appends `row(ordinal)` to the column builder.
*/
- def appendFrom(row: Row, ordinal: Int)
+ def appendFrom(row: InternalRow, ordinal: Int)
/**
* Column statistics information
@@ -68,7 +68,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId)
}
- override def appendFrom(row: Row, ordinal: Int): Unit = {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal))
columnType.append(row, ordinal, buffer)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index b0f983c18067..1bce214d1d6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.columnar
-import java.sql.Timestamp
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)()
@@ -54,7 +53,7 @@ private[sql] sealed trait ColumnStats extends Serializable {
/**
* Gathers statistics information from `row(ordinal)`.
*/
- def gatherStats(row: Row, ordinal: Int): Unit = {
+ def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (row.isNullAt(ordinal)) {
nullCount += 1
// 4 bytes for null position
@@ -67,23 +66,23 @@ private[sql] sealed trait ColumnStats extends Serializable {
* Column statistics represented as a single row, currently including closed lower bound, closed
* upper bound and null count.
*/
- def collectedStatistics: Row
+ def collectedStatistics: InternalRow
}
/**
* A no-op ColumnStats only used for testing purposes.
*/
private[sql] class NoopColumnStats extends ColumnStats {
- override def gatherStats(row: Row, ordinal: Int): Unit = super.gatherStats(row, ordinal)
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal)
- override def collectedStatistics: Row = Row(null, null, nullCount, count, 0L)
+ override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L)
}
private[sql] class BooleanColumnStats extends ColumnStats {
protected var upper = false
protected var lower = true
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getBoolean(ordinal)
@@ -93,14 +92,15 @@ private[sql] class BooleanColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class ByteColumnStats extends ColumnStats {
protected var upper = Byte.MinValue
protected var lower = Byte.MaxValue
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getByte(ordinal)
@@ -110,14 +110,15 @@ private[sql] class ByteColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class ShortColumnStats extends ColumnStats {
protected var upper = Short.MinValue
protected var lower = Short.MaxValue
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getShort(ordinal)
@@ -127,14 +128,15 @@ private[sql] class ShortColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class LongColumnStats extends ColumnStats {
protected var upper = Long.MinValue
protected var lower = Long.MaxValue
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getLong(ordinal)
@@ -144,14 +146,15 @@ private[sql] class LongColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class DoubleColumnStats extends ColumnStats {
protected var upper = Double.MinValue
protected var lower = Double.MaxValue
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getDouble(ordinal)
@@ -161,14 +164,15 @@ private[sql] class DoubleColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class FloatColumnStats extends ColumnStats {
protected var upper = Float.MinValue
protected var lower = Float.MaxValue
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getFloat(ordinal)
@@ -178,14 +182,15 @@ private[sql] class FloatColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class FixedDecimalColumnStats extends ColumnStats {
protected var upper: Decimal = null
protected var lower: Decimal = null
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row(ordinal).asInstanceOf[Decimal]
@@ -195,14 +200,15 @@ private[sql] class FixedDecimalColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class IntColumnStats extends ColumnStats {
protected var upper = Int.MinValue
protected var lower = Int.MaxValue
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getInt(ordinal)
@@ -212,14 +218,15 @@ private[sql] class IntColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class StringColumnStats extends ColumnStats {
protected var upper: UTF8String = null
protected var lower: UTF8String = null
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row(ordinal).asInstanceOf[UTF8String]
@@ -229,46 +236,34 @@ private[sql] class StringColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class DateColumnStats extends IntColumnStats
-private[sql] class TimestampColumnStats extends ColumnStats {
- protected var upper: Timestamp = null
- protected var lower: Timestamp = null
-
- override def gatherStats(row: Row, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
- if (!row.isNullAt(ordinal)) {
- val value = row(ordinal).asInstanceOf[Timestamp]
- if (upper == null || value.compareTo(upper) > 0) upper = value
- if (lower == null || value.compareTo(lower) < 0) lower = value
- sizeInBytes += TIMESTAMP.defaultSize
- }
- }
-
- override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
-}
+private[sql] class TimestampColumnStats extends LongColumnStats
private[sql] class BinaryColumnStats extends ColumnStats {
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
sizeInBytes += BINARY.actualSize(row, ordinal)
}
}
- override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(null, null, nullCount, count, sizeInBytes)
}
private[sql] class GenericColumnStats extends ColumnStats {
- override def gatherStats(row: Row, ordinal: Int): Unit = {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
sizeInBytes += GENERIC.actualSize(row, ordinal)
}
}
- override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes)
+ override def collectedStatistics: InternalRow =
+ InternalRow(null, null, nullCount, count, sizeInBytes)
}
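
A sketch of the contract these classes follow, assuming code that lives inside the org.apache.spark.sql.columnar package (the classes are private[sql]): statistics are gathered one cell at a time from InternalRows, and then read back as a single stats row of the form (lower, upper, nullCount, count, sizeInBytes).

import org.apache.spark.sql.catalyst.InternalRow

val stats = new IntColumnStats
Seq(3, 1, 2).foreach(v => stats.gatherStats(InternalRow(v), 0))
stats.gatherStats(InternalRow(null), 0)   // nulls only bump nullCount/sizeInBytes

// e.g. lower = 1 and upper = 3 here; null/size counters depend on the rows seen.
val collected = stats.collectedStatistics
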
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 20be5ca9d004..8e2102091776 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import java.sql.Timestamp
import scala.reflect.runtime.universe.TypeTag
@@ -26,6 +25,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/**
* An abstract class that represents type of a column. Used to append/extract Java objects into/from
@@ -321,7 +321,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
val length = buffer.getInt()
val stringBytes = new Array[Byte](length)
buffer.get(stringBytes, 0, length)
- UTF8String(stringBytes)
+ UTF8String.fromBytes(stringBytes)
}
override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
@@ -355,22 +355,20 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) {
}
}
-private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) {
- override def extract(buffer: ByteBuffer): Timestamp = {
- val timestamp = new Timestamp(buffer.getLong())
- timestamp.setNanos(buffer.getInt())
- timestamp
+private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) {
+ override def extract(buffer: ByteBuffer): Long = {
+ buffer.getLong
}
- override def append(v: Timestamp, buffer: ByteBuffer): Unit = {
- buffer.putLong(v.getTime).putInt(v.getNanos)
+ override def append(v: Long, buffer: ByteBuffer): Unit = {
+ buffer.putLong(v)
}
- override def getField(row: Row, ordinal: Int): Timestamp = {
- row(ordinal).asInstanceOf[Timestamp]
+ override def getField(row: Row, ordinal: Int): Long = {
+ row(ordinal).asInstanceOf[Long]
}
- override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = {
+ override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = {
row(ordinal) = value
}
}
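
Since TIMESTAMP is now a fixed 8-byte column, append/extract reduce to a single putLong/getLong. A minimal round-trip sketch using only java.nio (no Spark classes; the variable name is illustrative):

```scala
import java.nio.ByteBuffer

// One 8-byte long per timestamp in the column buffer, instead of the previous
// 12 bytes (a long of millis plus an int of nanos).
object TimestampLayoutSketch {
  def append(v: Long, buffer: ByteBuffer): Unit = buffer.putLong(v)
  def extract(buffer: ByteBuffer): Long = buffer.getLong

  def main(args: Array[String]): Unit = {
    val buffer = ByteBuffer.allocate(8)
    val microsSinceEpoch = 1433116800000000L // illustrative value
    append(microsSinceEpoch, buffer)
    buffer.flip()
    assert(extract(buffer) == microsSinceEpoch)
    println("8-byte timestamp round-trips")
  }
}
```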
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 3db26fad2b92..761f427b8cd0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -19,21 +19,16 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import org.apache.spark.{Accumulable, Accumulator, Accumulators}
-import org.apache.spark.sql.catalyst.expressions
-
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
-import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.{Accumulable, Accumulator, Accumulators}
private[sql] object InMemoryRelation {
def apply(
@@ -45,7 +40,7 @@ private[sql] object InMemoryRelation {
new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)()
}
-private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row)
+private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow)
private[sql] case class InMemoryRelation(
output: Seq[Attribute],
@@ -56,12 +51,12 @@ private[sql] case class InMemoryRelation(
tableName: Option[String])(
private var _cachedColumnBuffers: RDD[CachedBatch] = null,
private var _statistics: Statistics = null,
- private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null)
+ private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
extends LogicalPlan with MultiInstanceRelation {
- private val batchStats: Accumulable[ArrayBuffer[Row], Row] =
+ private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] =
if (_batchStats == null) {
- child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
+ child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow])
} else {
_batchStats
}
@@ -151,7 +146,7 @@ private[sql] case class InMemoryRelation(
rowCount += 1
}
- val stats = Row.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*)
+ val stats = InternalRow.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*)
batchStats += stats
CachedBatch(columnBuilders.map(_.build().array()), stats)
@@ -267,7 +262,7 @@ private[sql] case class InMemoryColumnarTableScan(
private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
if (enableAccumulators) {
readPartitions.setValue(0)
readBatches.setValue(0)
@@ -296,7 +291,7 @@ private[sql] case class InMemoryColumnarTableScan(
val nextRow = new SpecificMutableRow(requestedColumnDataTypes)
- def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[Row] = {
+ def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[InternalRow] = {
val rows = cacheBatches.flatMap { cachedBatch =>
// Build column accessors
val columnAccessors = requestedColumnIndices.map { batchColumnIndex =>
@@ -306,15 +301,15 @@ private[sql] case class InMemoryColumnarTableScan(
}
// Extract rows via column accessors
- new Iterator[Row] {
+ new Iterator[InternalRow] {
private[this] val rowLen = nextRow.length
- override def next(): Row = {
+ override def next(): InternalRow = {
var i = 0
while (i < rowLen) {
columnAccessors(i).extractTo(nextRow, i)
i += 1
}
- if (attributes.isEmpty) Row.empty else nextRow
+ if (attributes.isEmpty) InternalRow.empty else nextRow
}
override def hasNext: Boolean = columnAccessors(0).hasNext
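
The iterator returned by cachedBatchesToRows fills one mutable row per next() call, which is why downstream consumers that retain rows must copy them. A self-contained sketch of that reuse pattern, with plain arrays standing in for column accessors and SpecificMutableRow:

```scala
// Decode column-by-column into one reused mutable row per next() call.
object ColumnarScanSketch {
  def rowsOf(columns: Array[Array[Any]]): Iterator[Array[Any]] = {
    val numRows = if (columns.isEmpty) 0 else columns(0).length
    val reused = new Array[Any](columns.length)
    new Iterator[Array[Any]] {
      private var pos = 0
      override def hasNext: Boolean = pos < numRows
      override def next(): Array[Any] = {
        var i = 0
        while (i < columns.length) {
          reused(i) = columns(i)(pos) // "extractTo" the shared row
          i += 1
        }
        pos += 1
        reused
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val cols = Array(Array[Any](1, 2, 3), Array[Any]("a", "b", "c"))
    rowsOf(cols).foreach(r => println(r.mkString(",")))
  }
}
```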
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index f1f494ac26d0..ba47bc783f31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar
import java.nio.{ByteBuffer, ByteOrder}
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
/**
* A stackable trait used for building byte buffer for a column containing null values. Memory
@@ -52,7 +52,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
super.initialize(initialSize, columnName, useCompression)
}
- abstract override def appendFrom(row: Row, ordinal: Int): Unit = {
+ abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
columnStats.gatherStats(row, ordinal)
if (row.isNullAt(ordinal)) {
nulls = ColumnBuilder.ensureFreeSpace(nulls, 4)
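
A sketch of the null-tracking idea behind NullableColumnBuilder: non-null values go into the data buffer while null ordinals are recorded separately so a reader can splice them back in. This is plain Scala, not Spark's actual buffer layout:

```scala
import scala.collection.mutable.ArrayBuffer

// Values and null positions are tracked in separate structures.
object NullableBuilderSketch {
  final class IntColumnBuilder {
    private val data = ArrayBuffer.empty[Int]
    private val nullPositions = ArrayBuffer.empty[Int]
    private var rowCount = 0

    def append(value: Option[Int]): Unit = {
      value match {
        case Some(v) => data += v
        case None    => nullPositions += rowCount
      }
      rowCount += 1
    }

    def build(): (Seq[Int], Seq[Int]) = (data.toSeq, nullPositions.toSeq)
  }

  def main(args: Array[String]): Unit = {
    val b = new IntColumnBuilder
    Seq(Some(1), None, Some(3)).foreach(b.append)
    println(b.build()) // (List(1, 3),List(1))
  }
}
```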
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
index 8e2a1af6dae7..39b21ddb47ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar.compression
import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.Logging
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder}
import org.apache.spark.sql.types.AtomicType
@@ -66,7 +66,7 @@ private[sql] trait CompressibleColumnBuilder[T <: AtomicType]
encoder.compressionRatio < 0.8
}
- private def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ private def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
var i = 0
while (i < compressionEncoders.length) {
compressionEncoders(i).gatherCompressibilityStats(row, ordinal)
@@ -74,7 +74,7 @@ private[sql] trait CompressibleColumnBuilder[T <: AtomicType]
}
}
- abstract override def appendFrom(row: Row, ordinal: Int): Unit = {
+ abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
super.appendFrom(row, ordinal)
if (!row.isNullAt(ordinal)) {
gatherCompressibilityStats(row, ordinal)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
index 17c2d9b11118..4eaec6d853d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -18,14 +18,13 @@
package org.apache.spark.sql.columnar.compression
import java.nio.{ByteBuffer, ByteOrder}
-
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}
import org.apache.spark.sql.types.AtomicType
private[sql] trait Encoder[T <: AtomicType] {
- def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {}
+ def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {}
def compressedSize: Int
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index 534ae90ddbc8..5abc1259a19a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -22,8 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.runtimeMirror
-
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.types._
@@ -96,7 +95,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
override def compressedSize: Int = _compressedSize
- override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
val value = columnType.getField(row, ordinal)
val actualSize = columnType.actualSize(row, ordinal)
_uncompressedSize += actualSize
@@ -217,7 +216,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
// to store dictionary element count.
private var dictionarySize = 4
- override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
val value = columnType.getField(row, ordinal)
if (!overflow) {
@@ -310,7 +309,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
class Encoder extends compression.Encoder[BooleanType.type] {
private var _uncompressedSize = 0
- override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
_uncompressedSize += BOOLEAN.defaultSize
}
@@ -404,7 +403,7 @@ private[sql] case object IntDelta extends CompressionScheme {
private var prevValue: Int = _
- override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
val value = row.getInt(ordinal)
val delta = value - prevValue
@@ -484,7 +483,7 @@ private[sql] case object LongDelta extends CompressionScheme {
private var prevValue: Long = _
- override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
val value = row.getLong(ordinal)
val delta = value - prevValue
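
For IntDelta and LongDelta, the encoder's compressibility stats track how often consecutive values differ by only a small amount. A standalone sketch of the underlying delta idea (fall back to the full value when the delta does not fit in a byte):

```scala
// Encode each value as a one-byte delta from the previous value where possible.
object LongDeltaSketch {
  def encode(values: Seq[Long]): Seq[Either[Long, Byte]] = {
    var prev = 0L
    var first = true
    values.map { v =>
      val delta = v - prev
      prev = v
      val out =
        if (!first && delta >= Byte.MinValue && delta <= Byte.MaxValue) Right(delta.toByte)
        else Left(v) // store the full value
      first = false
      out
    }
  }

  def main(args: Array[String]): Unit = {
    println(encode(Seq(100L, 101L, 103L, 5000L, 5001L)))
    // List(Left(100), Right(1), Right(2), Left(5000), Right(1))
  }
}
```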
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 8d16749697aa..6e8a5ef18ab6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -20,12 +20,10 @@ package org.apache.spark.sql.execution
import java.util.HashMap
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.SQLContext
/**
* :: DeveloperApi ::
@@ -121,11 +119,11 @@ case class Aggregate(
}
}
- protected override def doExecute(): RDD[Row] = attachTree(this, "execute") {
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
if (groupingExpressions.isEmpty) {
child.execute().mapPartitions { iter =>
val buffer = newAggregateBuffer()
- var currentRow: Row = null
+ var currentRow: InternalRow = null
while (iter.hasNext) {
currentRow = iter.next()
var i = 0
@@ -147,10 +145,10 @@ case class Aggregate(
}
} else {
child.execute().mapPartitions { iter =>
- val hashTable = new HashMap[Row, Array[AggregateFunction]]
+ val hashTable = new HashMap[InternalRow, Array[AggregateFunction]]
val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output)
- var currentRow: Row = null
+ var currentRow: InternalRow = null
while (iter.hasNext) {
currentRow = iter.next()
val currentGroup = groupingProjection(currentRow)
@@ -167,7 +165,7 @@ case class Aggregate(
}
}
- new Iterator[Row] {
+ new Iterator[InternalRow] {
private[this] val hashTableIter = hashTable.entrySet().iterator()
private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
private[this] val resultProjection =
@@ -177,7 +175,7 @@ case class Aggregate(
override final def hasNext: Boolean = hashTableIter.hasNext
- override final def next(): Row = {
+ override final def next(): InternalRow = {
val currentEntry = hashTableIter.next()
val currentGroup = currentEntry.getKey
val currentBuffer = currentEntry.getValue
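
The grouped path above builds a HashMap from grouping key to aggregate buffers and then iterates the map to emit results. A minimal hash-aggregation sketch, with plain Scala collections standing in for InternalRow and AggregateFunction:

```scala
import scala.collection.mutable

// One buffer of aggregate state per grouping key; emit by iterating the map.
object HashAggregateSketch {
  def sumByKey[K](rows: Iterator[(K, Long)]): Iterator[(K, Long)] = {
    val hashTable = mutable.HashMap.empty[K, Long]
    while (rows.hasNext) {
      val (key, value) = rows.next()
      hashTable.update(key, hashTable.getOrElse(key, 0L) + value)
    }
    hashTable.iterator
  }

  def main(args: Array[String]): Unit = {
    val input = Iterator(("a", 1L), ("b", 2L), ("a", 3L))
    sumByKey(input).foreach(println) // (a,4), (b,2) in some order
  }
}
```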
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 5fcc48a67948..a4b38d364d54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -103,7 +103,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
sqlContext.conf.useCompression,
sqlContext.conf.columnBatchSize,
storageLevel,
- query.queryExecution.executedPlan,
+ sqlContext.executePlan(query.logicalPlan).executedPlan,
tableName))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index f25d10fec041..edc64a03335d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -17,29 +17,20 @@
package org.apache.spark.sql.execution
-import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType
-import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.util.MutablePair
-
-object Exchange {
- /**
- * Returns true when the ordering expressions are a subset of the key.
- * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]].
- */
- def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
- desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
- }
-}
+import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
/**
* :: DeveloperApi ::
@@ -91,11 +82,7 @@ case class Exchange(
shuffleManager.isInstanceOf[UnsafeShuffleManager]
val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
- if (newOrdering.nonEmpty) {
- // If a new ordering is required, then records will be sorted with Spark's `ExternalSorter`,
- // which requires a defensive copy.
- true
- } else if (sortBasedShuffleOn) {
+ if (sortBasedShuffleOn) {
val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
// If we're using the original SortShuffleManager and the number of output partitions is
@@ -106,8 +93,11 @@ case class Exchange(
} else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
// SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting
// them. This optimization is guarded by a feature-flag and is only applied in cases where
- // shuffle dependency does not specify an ordering and the record serializer has certain
- // properties. If this optimization is enabled, we can safely avoid the copy.
+ // shuffle dependency does not specify an aggregator or ordering and the record serializer
+ // has certain properties. If this optimization is enabled, we can safely avoid the copy.
+ //
+ // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only
+ // need to check whether the optimization is enabled and supported by our serializer.
//
// This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
false
@@ -118,9 +108,12 @@ case class Exchange(
// both cases, we must copy.
true
}
- } else {
+ } else if (shuffleManager.isInstanceOf[HashShuffleManager]) {
// We're using hash-based shuffle, so we don't need to copy.
false
+ } else {
+ // Catch-all case to safely handle any future ShuffleManager implementations.
+ true
}
}
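
The copy decision above can be read as a small pure function over the shuffle configuration; the sketch below mirrors its shape, including the new conservative catch-all for unknown ShuffleManager implementations. The boolean parameters are stand-ins for the real checks (bypass-merge threshold, serializer relocation support):

```scala
// Shape of the defensive-copy decision, not Spark's actual implementation.
object CopyBeforeShuffleSketch {
  sealed trait ShuffleKind
  case object SortBased extends ShuffleKind
  case object HashBased extends ShuffleKind
  case object Unknown extends ShuffleKind

  def needToCopy(
      kind: ShuffleKind,
      fewPartitionsAndBypassSupported: Boolean,
      serializerSupportsRelocation: Boolean): Boolean = kind match {
    case SortBased if fewPartitionsAndBypassSupported => false // records never buffered in memory
    case SortBased if serializerSupportsRelocation    => false // serialized sorting path
    case SortBased                                    => true  // rows buffered, must copy
    case HashBased                                    => false
    case Unknown                                      => true  // safe default for future managers
  }

  def main(args: Array[String]): Unit = {
    println(needToCopy(SortBased, fewPartitionsAndBypassSupported = false,
      serializerSupportsRelocation = true)) // false
  }
}
```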
@@ -143,7 +136,6 @@ case class Exchange(
private def getSerializer(
keySchema: Array[DataType],
valueSchema: Array[DataType],
- hasKeyOrdering: Boolean,
numPartitions: Int): Serializer = {
// It is true when there is no field that needs to be write out.
// For now, we will not use SparkSqlSerializer2 when noField is true.
@@ -159,7 +151,7 @@ case class Exchange(
val serializer = if (useSqlSerializer2) {
logInfo("Using SparkSqlSerializer2.")
- new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
+ new SparkSqlSerializer2(keySchema, valueSchema)
} else {
logInfo("Using SparkSqlSerializer.")
new SparkSqlSerializer(sparkConf)
@@ -168,12 +160,12 @@ case class Exchange(
serializer
}
- protected override def doExecute(): RDD[Row] = attachTree(this , "execute") {
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
val keySchema = expressions.map(_.dataType).toArray
val valueSchema = child.output.map(_.dataType).toArray
- val serializer = getSerializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions)
+ val serializer = getSerializer(keySchema, valueSchema, numPartitions)
val part = new HashPartitioner(numPartitions)
val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
@@ -184,27 +176,24 @@ case class Exchange(
} else {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
- val mutablePair = new MutablePair[Row, Row]()
+ val mutablePair = new MutablePair[InternalRow, InternalRow]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
}
- val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
- if (newOrdering.nonEmpty) {
- shuffled.setKeyOrdering(keyOrdering)
- }
+ val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
shuffled.setSerializer(serializer)
shuffled.map(_._2)
case RangePartitioning(sortingExpressions, numPartitions) =>
val keySchema = child.output.map(_.dataType).toArray
- val serializer = getSerializer(keySchema, null, newOrdering.nonEmpty, numPartitions)
+ val serializer = getSerializer(keySchema, null, numPartitions)
val childRdd = child.execute()
val part: Partitioner = {
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
val rddForSampling = childRdd.mapPartitions { iter =>
- val mutablePair = new MutablePair[Row, Null]()
+ val mutablePair = new MutablePair[InternalRow, Null]()
iter.map(row => mutablePair.update(row.copy(), null))
}
// TODO: RangePartitioner should take an Ordering.
@@ -216,32 +205,31 @@ case class Exchange(
childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
} else {
childRdd.mapPartitions { iter =>
- val mutablePair = new MutablePair[Row, Null]()
+ val mutablePair = new MutablePair[InternalRow, Null]()
iter.map(row => mutablePair.update(row, null))
}
}
- val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
- if (newOrdering.nonEmpty) {
- shuffled.setKeyOrdering(keyOrdering)
- }
+ val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part)
shuffled.setSerializer(serializer)
shuffled.map(_._1)
case SinglePartition =>
val valueSchema = child.output.map(_.dataType).toArray
- val serializer = getSerializer(null, valueSchema, hasKeyOrdering = false, 1)
+ val serializer = getSerializer(null, valueSchema, numPartitions = 1)
val partitioner = new HashPartitioner(1)
val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
- child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) }
+ child.execute().mapPartitions {
+ iter => iter.map(r => (null, r.copy()))
+ }
} else {
child.execute().mapPartitions { iter =>
- val mutablePair = new MutablePair[Null, Row]()
+ val mutablePair = new MutablePair[Null, InternalRow]()
iter.map(r => mutablePair.update(null, r))
}
}
- val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
+ val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner)
shuffled.setSerializer(serializer)
shuffled.map(_._2)
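
When no defensive copy is needed, a single MutablePair is refilled per input row instead of allocating a tuple per record, which is safe because each pair is consumed (hashed and serialized) before the next update. A sketch of that reuse pattern with an illustrative MutableKV class:

```scala
// One mutable pair is reused across the whole partition.
object MutablePairSketch {
  final class MutableKV[K, V](var key: K, var value: V) {
    def update(k: K, v: V): this.type = { key = k; value = v; this }
  }

  def main(args: Array[String]): Unit = {
    val reused = new MutableKV[Int, String](0, "")
    // Each element is consumed before the next update, so reuse is safe here.
    Iterator("x", "y", "z").map(s => reused.update(s.hashCode, s)).foreach { kv =>
      println(s"${kv.key} -> ${kv.value}")
    }
  }
}
```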
@@ -306,29 +294,24 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
child: SparkPlan): SparkPlan = {
val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
val needsShuffle = child.outputPartitioning != partitioning
- val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)
- if (needSort && needsShuffle && canSortWithShuffle) {
- Exchange(partitioning, rowOrdering, child)
+ val withShuffle = if (needsShuffle) {
+ Exchange(partitioning, Nil, child)
} else {
- val withShuffle = if (needsShuffle) {
- Exchange(partitioning, Nil, child)
- } else {
- child
- }
+ child
+ }
- val withSort = if (needSort) {
- if (sqlContext.conf.externalSortEnabled) {
- ExternalSort(rowOrdering, global = false, withShuffle)
- } else {
- Sort(rowOrdering, global = false, withShuffle)
- }
+ val withSort = if (needSort) {
+ if (sqlContext.conf.externalSortEnabled) {
+ ExternalSort(rowOrdering, global = false, withShuffle)
} else {
- withShuffle
+ Sort(rowOrdering, global = false, withShuffle)
}
-
- withSort
+ } else {
+ withShuffle
}
+
+ withSort
}
if (meetsRequirements && compatible && !needsAnySort) {
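
With canSortWithShuffle gone, addOperatorsIfNecessary always composes an optional Exchange with an optional Sort on top. A toy sketch of that composition (Plan is an illustrative ADT, not SparkPlan):

```scala
// Treat "add shuffle" and "add sort" as independent wrappers applied in sequence.
object EnsureRequirementsSketch {
  sealed trait Plan
  case class Leaf(name: String) extends Plan
  case class Shuffle(child: Plan) extends Plan
  case class Sort(child: Plan) extends Plan

  def addOperatorsIfNecessary(needsShuffle: Boolean, needsSort: Boolean, child: Plan): Plan = {
    val withShuffle = if (needsShuffle) Shuffle(child) else child
    if (needsSort) Sort(withShuffle) else withShuffle
  }

  def main(args: Array[String]): Unit = {
    println(addOperatorsIfNecessary(needsShuffle = true, needsSort = true, Leaf("scan")))
    // Sort(Shuffle(Leaf(scan)))
  }
}
```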
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index f931dc95ef57..da27a753a710 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
@@ -31,7 +31,7 @@ import org.apache.spark.sql.{Row, SQLContext}
*/
@DeveloperApi
object RDDConversions {
- def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[Row] = {
+ def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
data.mapPartitions { iterator =>
val numColumns = outputTypes.length
val mutableRow = new GenericMutableRow(numColumns)
@@ -51,7 +51,7 @@ object RDDConversions {
/**
* Convert the objects inside Row into the types Catalyst expected.
*/
- def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[Row] = {
+ def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[InternalRow] = {
data.mapPartitions { iterator =>
val numColumns = outputTypes.length
val mutableRow = new GenericMutableRow(numColumns)
@@ -70,7 +70,9 @@ object RDDConversions {
}
/** Logical plan node for scanning data from an RDD. */
-private[sql] case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext)
+private[sql] case class LogicalRDD(
+ output: Seq[Attribute],
+ rdd: RDD[InternalRow])(sqlContext: SQLContext)
extends LogicalPlan with MultiInstanceRelation {
override def children: Seq[LogicalPlan] = Nil
@@ -91,13 +93,15 @@ private[sql] case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlCon
}
/** Physical plan node for scanning data from an RDD. */
-private[sql] case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
- protected override def doExecute(): RDD[Row] = rdd
+private[sql] case class PhysicalRDD(
+ output: Seq[Attribute],
+ rdd: RDD[InternalRow]) extends LeafNode {
+ protected override def doExecute(): RDD[InternalRow] = rdd
}
/** Logical plan node for scanning data from a local collection. */
private[sql]
-case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext)
+case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[InternalRow])(sqlContext: SQLContext)
extends LogicalPlan with MultiInstanceRelation {
override def children: Seq[LogicalPlan] = Nil
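
RDDConversions walks each Product's fields into a reused mutable row while applying per-column conversions. A standalone sketch of that pattern, with Array[Any] standing in for GenericMutableRow and the conversion left as identity:

```scala
// Fill a reused mutable row from each Product's fields.
object ProductToRowSketch {
  def toRows[A <: Product](data: Iterator[A], numColumns: Int): Iterator[Array[Any]] = {
    val mutableRow = new Array[Any](numColumns)
    data.map { product =>
      var i = 0
      while (i < numColumns) {
        mutableRow(i) = product.productElement(i)
        i += 1
      }
      mutableRow
    }
  }

  case class Person(name: String, age: Int)

  def main(args: Array[String]): Unit = {
    toRows(Iterator(Person("ann", 30), Person("bob", 25)), numColumns = 2)
      .foreach(r => println(r.mkString(",")))
  }
}
```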
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index f16ca36909fa..42a0c1be4f69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -19,10 +19,9 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
/**
* Apply all of the GroupExpressions to every input row, hence we will get

@@ -34,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partit
*/
@DeveloperApi
case class Expand(
- projections: Seq[GroupExpression],
+ projections: Seq[Seq[Expression]],
output: Seq[Attribute],
child: SparkPlan)
extends UnaryNode {
@@ -43,22 +42,22 @@ case class Expand(
// as UNKNOWN partitioning
override def outputPartitioning: Partitioning = UnknownPartitioning(0)
- protected override def doExecute(): RDD[Row] = attachTree(this, "execute") {
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
child.execute().mapPartitions { iter =>
// TODO Move out projection objects creation and transfer to
// workers via closure. However we can't assume the Projection
// is serializable because of the code gen, so we have to
// create the projections within each of the partition processing.
- val groups = projections.map(ee => newProjection(ee.children, child.output)).toArray
+ val groups = projections.map(ee => newProjection(ee, child.output)).toArray
- new Iterator[Row] {
- private[this] var result: Row = _
+ new Iterator[InternalRow] {
+ private[this] var result: InternalRow = _
private[this] var idx = -1 // -1 means the initial state
- private[this] var input: Row = _
+ private[this] var input: InternalRow = _
override final def hasNext: Boolean = (-1 < idx && idx < groups.length) || iter.hasNext
- override final def next(): Row = {
+ override final def next(): InternalRow = {
if (idx <= 0) {
// in the initial (-1) or beginning(0) of a new input row, fetch the next input tuple
input = iter.next()
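
Expand now takes plain Seq[Seq[Expression]] projections and emits one output row per projection for every input row (the basis for grouping sets). A sketch of that per-partition behaviour, with functions standing in for compiled projections:

```scala
// One input row fans out to projections.length output rows.
object ExpandSketch {
  def expand[A, B](rows: Iterator[A], projections: Seq[A => B]): Iterator[B] =
    rows.flatMap(row => projections.iterator.map(p => p(row)))

  def main(args: Array[String]): Unit = {
    val projections: Seq[((String, Int)) => (Option[String], Int)] = Seq(
      r => (Some(r._1), r._2), // group by key
      r => (None, r._2)        // grand-total bucket
    )
    expand(Iterator(("a", 1), ("b", 2)), projections).foreach(println)
  }
}
```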
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index dd02c1f4573b..c1665f78a960 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -25,12 +25,12 @@ import org.apache.spark.sql.catalyst.expressions._
* For lazy computing, be sure generator.terminate() is called at the very end
* TODO reusing the CompletionIterator?
*/
-private[execution] sealed case class LazyIterator(func: () => TraversableOnce[Row])
- extends Iterator[Row] {
+private[execution] sealed case class LazyIterator(func: () => TraversableOnce[InternalRow])
+ extends Iterator[InternalRow] {
lazy val results = func().toIterator
override def hasNext: Boolean = results.hasNext
- override def next(): Row = results.next()
+ override def next(): InternalRow = results.next()
}
/**
@@ -58,11 +58,11 @@ case class Generate(
val boundGenerator = BindReferences.bindReference(generator, child.output)
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
// boundGenerator.terminate() should be triggered after all of the rows in the partition
if (join) {
child.execute().mapPartitions { iter =>
- val generatorNullRow = Row.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null))
+ val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null))
val joinedRow = new JoinedRow
iter.flatMap { row =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 2ec7d4fbc92d..ba2c8f53d702 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -66,7 +66,7 @@ case class GeneratedAggregate(
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
val aggregatesToCompute = aggregateExpressions.flatMap { a =>
a.collect { case agg: AggregateExpression => agg}
}
@@ -118,7 +118,7 @@ case class GeneratedAggregate(
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case cs @ CombineSum(expr) =>
- val calcType = expr.dataType
+ val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
@@ -129,7 +129,7 @@ case class GeneratedAggregate(
val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
val initialValue = Literal.create(null, calcType)
- // Coalasce avoids double calculation...
+ // Coalesce avoids double calculation...
// but really, common sub expression elimination would be better....
val zero = Cast(Literal(0), calcType)
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
@@ -138,15 +138,15 @@ case class GeneratedAggregate(
case UnscaledValue(e) => e
case _ => expr
}
- // partial sum result can be null only when no input rows present
+ // partial sum result can be null only when no input rows present
val updateFunction = If(
IsNotNull(actualExpr),
Coalesce(
Add(
- Coalesce(currentSum :: zero :: Nil),
+ Coalesce(currentSum :: zero :: Nil),
Cast(expr, calcType)) :: currentSum :: zero :: Nil),
currentSum)
-
+
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
@@ -155,7 +155,7 @@ case class GeneratedAggregate(
}
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
+
case m @ Max(expr) =>
val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
val initialValue = Literal.create(null, expr.dataType)
@@ -214,18 +214,18 @@ case class GeneratedAggregate(
}.toMap
val namedGroups = groupingExpressions.zipWithIndex.map {
- case (ne: NamedExpression, _) => (ne, ne)
- case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
+ case (ne: NamedExpression, _) => (ne, ne.toAttribute)
+ case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
}
- val groupMap: Map[Expression, Attribute] =
- namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap
-
// The set of expressions that produce the final output given the aggregation buffer and the
// grouping expressions.
val resultExpressions = aggregateExpressions.map(_.transform {
case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
- case e: Expression if groupMap.contains(e) => groupMap(e)
+ case e: Expression =>
+ namedGroups.collectFirst {
+ case (expr, attr) if expr semanticEquals e => attr
+ }.getOrElse(e)
})
val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
@@ -265,7 +265,7 @@ case class GeneratedAggregate(
val resultProjectionBuilder =
newMutableProjection(
resultExpressions,
- (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
+ namedGroups.map(_._2) ++ computationSchema)
log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
val joinedRow = new JoinedRow3
@@ -273,7 +273,7 @@ case class GeneratedAggregate(
if (groupingExpressions.isEmpty) {
// TODO: Codegening anything other than the updateProjection is probably over kill.
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
- var currentRow: Row = null
+ var currentRow: InternalRow = null
updateProjection.target(buffer)
while (iter.hasNext) {
@@ -295,19 +295,19 @@ case class GeneratedAggregate(
)
while (iter.hasNext) {
- val currentRow: Row = iter.next()
- val groupKey: Row = groupProjection(currentRow)
+ val currentRow: InternalRow = iter.next()
+ val groupKey: InternalRow = groupProjection(currentRow)
val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
}
- new Iterator[Row] {
+ new Iterator[InternalRow] {
private[this] val mapIterator = aggregationMap.iterator()
private[this] val resultProjection = resultProjectionBuilder()
def hasNext: Boolean = mapIterator.hasNext
- def next(): Row = {
+ def next(): InternalRow = {
val entry = mapIterator.next()
val result = resultProjection(joinedRow(entry.key, entry.value))
if (hasNext) {
@@ -326,9 +326,9 @@ case class GeneratedAggregate(
if (unsafeEnabled) {
log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
}
- val buffers = new java.util.HashMap[Row, MutableRow]()
+ val buffers = new java.util.HashMap[InternalRow, MutableRow]()
- var currentRow: Row = null
+ var currentRow: InternalRow = null
while (iter.hasNext) {
currentRow = iter.next()
val currentGroup = groupProjection(currentRow)
@@ -342,13 +342,13 @@ case class GeneratedAggregate(
updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
}
- new Iterator[Row] {
+ new Iterator[InternalRow] {
private[this] val resultIterator = buffers.entrySet.iterator()
private[this] val resultProjection = resultProjectionBuilder()
def hasNext: Boolean = resultIterator.hasNext
- def next(): Row = {
+ def next(): InternalRow = {
val currentGroup = resultIterator.next()
resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
}
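
The rewritten resultExpressions transform substitutes a grouping attribute wherever an output expression is semantically equal to a grouping expression. A sketch of that lookup, with string equality standing in for semanticEquals:

```scala
// Replace an output expression with its grouping attribute when a semantic match exists.
object GroupingRewriteSketch {
  def rewrite(resultExprs: Seq[String], namedGroups: Seq[(String, String)]): Seq[String] =
    resultExprs.map { e =>
      namedGroups.collectFirst { case (expr, attr) if expr == e => attr }.getOrElse(e)
    }

  def main(args: Array[String]): Unit = {
    val namedGroups = Seq("year(date)" -> "GroupingExpr0")
    println(rewrite(Seq("year(date)", "sum(sales)"), namedGroups))
    // List(GroupingExpr0, sum(sales))
  }
}
```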
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
index 03bee80ad7f3..cd341180b610 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
@@ -19,18 +19,20 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
import org.apache.spark.sql.catalyst.expressions.Attribute
/**
* Physical plan node for scanning data from a local collection.
*/
-private[sql] case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode {
+private[sql] case class LocalTableScan(
+ output: Seq[Attribute],
+ rows: Seq[InternalRow]) extends LeafNode {
private lazy val rdd = sqlContext.sparkContext.parallelize(rows)
- protected override def doExecute(): RDD[Row] = rdd
+ protected override def doExecute(): RDD[InternalRow] = rdd
override def executeCollect(): Array[Row] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 435ac011178d..2b8d30294293 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
@@ -79,11 +80,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
/**
- * Returns the result of this query as an RDD[Row] by delegating to doExecute
+ * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute
* after adding query plan information to created RDDs for visualization.
* Concrete implementations of SparkPlan should override doExecute instead.
*/
- final def execute(): RDD[Row] = {
+ final def execute(): RDD[InternalRow] = {
RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
doExecute()
}
@@ -91,9 +92,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Overridden by concrete implementations of SparkPlan.
- * Produces the result of the query as an RDD[Row]
+ * Produces the result of the query as an RDD[InternalRow]
*/
- protected def doExecute(): RDD[Row]
+ protected def doExecute(): RDD[InternalRow]
/**
* Runs this query returning the result as an array.
@@ -117,7 +118,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
val childRDD = execute().map(_.copy())
- val buf = new ArrayBuffer[Row]
+ val buf = new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
var partsScanned = 0
while (buf.size < n && partsScanned < totalParts) {
@@ -140,7 +141,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
val sc = sqlContext.sparkContext
val res =
- sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)
+ sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p,
+ allowLocal = false)
res.foreach(buf ++= _.take(n - buf.size))
partsScanned += numPartsToTry
@@ -154,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
log.debug(
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
- if (codegenEnabled) {
+ if (codegenEnabled && expressions.forall(_.isThreadSafe)) {
GenerateProjection.generate(expressions, inputSchema)
} else {
new InterpretedProjection(expressions, inputSchema)
@@ -166,7 +168,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
inputSchema: Seq[Attribute]): () => MutableProjection = {
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
- if(codegenEnabled) {
+ if(codegenEnabled && expressions.forall(_.isThreadSafe)) {
+
GenerateMutableProjection.generate(expressions, inputSchema)
} else {
() => new InterpretedMutableProjection(expressions, inputSchema)
@@ -175,15 +178,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
protected def newPredicate(
- expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
- if (codegenEnabled) {
+ expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
+ if (codegenEnabled && expression.isThreadSafe) {
GeneratePredicate.generate(expression, inputSchema)
} else {
InterpretedPredicate.create(expression, inputSchema)
}
}
- protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
+ protected def newOrdering(
+ order: Seq[SortOrder],
+ inputSchema: Seq[Attribute]): Ordering[InternalRow] = {
if (codegenEnabled) {
GenerateOrdering.generate(order, inputSchema)
} else {
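
executeTake grows the number of partitions it scans per round until it has collected n rows. A simplified sketch of that loop (the growth factor and plain Seq partitions are illustrative):

```scala
import scala.collection.mutable.ArrayBuffer

// Scan a few partitions; if still short of n rows, scan more on the next round.
object IncrementalTakeSketch {
  def take[A](partitions: Seq[Seq[A]], n: Int): Seq[A] = {
    val buf = new ArrayBuffer[A]
    var partsScanned = 0
    var numPartsToTry = 1
    while (buf.size < n && partsScanned < partitions.length) {
      val upTo = math.min(partsScanned + numPartsToTry, partitions.length)
      (partsScanned until upTo).foreach { p =>
        buf ++= partitions(p).take(n - buf.size)
      }
      partsScanned = upTo
      numPartsToTry *= 4 // scan more partitions next round if still short
    }
    buf.take(n).toSeq
  }

  def main(args: Array[String]): Unit = {
    println(take(Seq(Seq(1, 2), Seq(3, 4, 5), Seq(6)), n = 4)) // List(1, 2, 3, 4)
  }
}
```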
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index 256d527d7b63..15b6936acd59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -20,15 +20,15 @@ package org.apache.spark.sql.execution
import java.io._
import java.math.{BigDecimal, BigInteger}
import java.nio.ByteBuffer
-import java.sql.Timestamp
import scala.reflect.ClassTag
-import org.apache.spark.serializer._
import org.apache.spark.Logging
+import org.apache.spark.serializer._
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, MutableRow, SpecificMutableRow}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/**
* The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in
@@ -86,7 +86,6 @@ private[sql] class Serializer2SerializationStream(
private[sql] class Serializer2DeserializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
- hasKeyOrdering: Boolean,
in: InputStream)
extends DeserializationStream with Logging {
@@ -96,14 +95,9 @@ private[sql] class Serializer2DeserializationStream(
if (schema == null) {
() => null
} else {
- if (hasKeyOrdering) {
- // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row.
- () => new GenericMutableRow(schema.length)
- } else {
- // It is safe to reuse the mutable row.
- val mutableRow = new SpecificMutableRow(schema)
- () => mutableRow
- }
+ // It is safe to reuse the mutable row.
+ val mutableRow = new SpecificMutableRow(schema)
+ () => mutableRow
}
}
@@ -133,8 +127,7 @@ private[sql] class Serializer2DeserializationStream(
private[sql] class SparkSqlSerializer2Instance(
keySchema: Array[DataType],
- valueSchema: Array[DataType],
- hasKeyOrdering: Boolean)
+ valueSchema: Array[DataType])
extends SerializerInstance {
def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -151,7 +144,7 @@ private[sql] class SparkSqlSerializer2Instance(
}
def deserializeStream(s: InputStream): DeserializationStream = {
- new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
+ new Serializer2DeserializationStream(keySchema, valueSchema, s)
}
}
@@ -164,14 +157,13 @@ private[sql] class SparkSqlSerializer2Instance(
*/
private[sql] class SparkSqlSerializer2(
keySchema: Array[DataType],
- valueSchema: Array[DataType],
- hasKeyOrdering: Boolean)
+ valueSchema: Array[DataType])
extends Serializer
with Logging
with Serializable{
def newInstance(): SerializerInstance =
- new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)
+ new SparkSqlSerializer2Instance(keySchema, valueSchema)
override def supportsRelocationOfSerializedObjects: Boolean = {
// SparkSqlSerializer2 is stateless and writes no stream headers
@@ -304,11 +296,7 @@ private[sql] object SparkSqlSerializer2 {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
- val timestamp = row.getAs[java.sql.Timestamp](i)
- val time = timestamp.getTime
- val nanos = timestamp.getNanos
- out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value.
- out.writeInt(nanos) // Write the nanoseconds part.
+ out.writeLong(row.getAs[Long](i))
}
case StringType =>
@@ -429,11 +417,7 @@ private[sql] object SparkSqlSerializer2 {
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
- val time = in.readLong() // Read the milliseconds value.
- val nanos = in.readInt() // Read the nanoseconds part.
- val timestamp = new Timestamp(time)
- timestamp.setNanos(nanos)
- mutableRow.update(i, timestamp)
+ mutableRow.update(i, in.readLong())
}
case StringType =>
@@ -443,7 +427,7 @@ private[sql] object SparkSqlSerializer2 {
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
- mutableRow.update(i, UTF8String(bytes))
+ mutableRow.update(i, UTF8String.fromBytes(bytes))
}
case BinaryType =>
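
With timestamps carried as a Long, the serializer's wire format for a timestamp column is just the NULL/NOT_NULL marker byte plus one writeLong/readLong. A round-trip sketch using java.io streams (the marker values are illustrative):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// One signed long per non-null timestamp, preceded by a null marker byte.
object TimestampWireSketch {
  val NULL: Byte = 0
  val NOT_NULL: Byte = 1

  def write(out: DataOutputStream, value: Option[Long]): Unit = value match {
    case None    => out.writeByte(NULL)
    case Some(v) => out.writeByte(NOT_NULL); out.writeLong(v)
  }

  def read(in: DataInputStream): Option[Long] =
    if (in.readByte() == NULL) None else Some(in.readLong())

  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    write(out, Some(1433116800000000L))
    write(out, None)
    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    println((read(in), read(in))) // (Some(1433116800000000),None)
  }
}
```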
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index d0a1ad00560d..422992d019c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -203,7 +203,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
protected lazy val singleRowRdd =
- sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
+ sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1)
object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
@@ -284,8 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r: RunnableCommand => ExecutedCommand(r) :: Nil
case logical.Distinct(child) =>
- execution.Distinct(partial = false,
- execution.Distinct(partial = true, planLater(child))) :: Nil
+ throw new IllegalStateException(
+ "logical distinct operator should have been replaced by aggregate in the optimizer")
case logical.Repartition(numPartitions, shuffle, child) =>
execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil
case logical.SortPartitions(sortExprs, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index c4327ce262ac..fd6f1d7ae125 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -20,9 +20,8 @@ package org.apache.spark.sql.execution
import java.util
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.util.collection.CompactBuffer
/**
@@ -112,16 +111,16 @@ case class Window(
}
}
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitions { iter =>
- new Iterator[Row] {
+ new Iterator[InternalRow] {
// Although input rows are grouped based on windowSpec.partitionSpec, we need to
// know when we have a new partition.
// This is to manually construct an ordering that can be used to compare rows.
// TODO: We may want to have a newOrdering that takes BoundReferences.
// So, we can take advantave of code gen.
- private val partitionOrdering: Ordering[Row] =
+ private val partitionOrdering: Ordering[InternalRow] =
RowOrdering.forSchema(windowSpec.partitionSpec.map(_.dataType))
// This is used to project expressions for the partition specification.
@@ -137,13 +136,13 @@ case class Window(
// The number of buffered rows in the inputRowBuffer (the size of the current partition).
var partitionSize: Int = 0
// The buffer used to buffer rows in a partition.
- var inputRowBuffer: CompactBuffer[Row] = _
+ var inputRowBuffer: CompactBuffer[InternalRow] = _
// The partition key of the current partition.
- var currentPartitionKey: Row = _
+ var currentPartitionKey: InternalRow = _
// The partition key of next partition.
- var nextPartitionKey: Row = _
+ var nextPartitionKey: InternalRow = _
// The first row of next partition.
- var firstRowInNextPartition: Row = _
+ var firstRowInNextPartition: InternalRow = _
// Indicates if this partition is the last one in the iter.
var lastPartition: Boolean = false
@@ -316,7 +315,7 @@ case class Window(
!lastPartition || (rowPosition < partitionSize)
}
- override final def next(): Row = {
+ override final def next(): InternalRow = {
if (hasNext) {
if (rowPosition == partitionSize) {
// All rows of this buffer have been consumed.
@@ -353,7 +352,7 @@ case class Window(
// Fetch the next partition.
private def fetchNextPartition(): Unit = {
// Create a new buffer for input rows.
- inputRowBuffer = new CompactBuffer[Row]()
+ inputRowBuffer = new CompactBuffer[InternalRow]()
// We already have the first row for this partition
// (recorded in firstRowInNextPartition). Add it back.
inputRowBuffer += firstRowInNextPartition
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6cb67b4bbbb6..7aedd630e387 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,16 +17,17 @@
package org.apache.spark.sql.execution
-import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.util.{CompletionIterator, MutablePair}
import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.{CompletionIterator, MutablePair}
+import org.apache.spark.{HashPartitioner, SparkEnv}
/**
* :: DeveloperApi ::
@@ -37,7 +38,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
@transient lazy val buildProjection = newMutableProjection(projectList, child.output)
- protected override def doExecute(): RDD[Row] = child.execute().mapPartitions { iter =>
+ protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
val resuableProjection = buildProjection()
iter.map(resuableProjection)
}
@@ -52,9 +53,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
- @transient lazy val conditionEvaluator: (Row) => Boolean = newPredicate(condition, child.output)
+ @transient lazy val conditionEvaluator: (InternalRow) => Boolean =
+ newPredicate(condition, child.output)
- protected override def doExecute(): RDD[Row] = child.execute().mapPartitions { iter =>
+ protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
iter.filter(conditionEvaluator)
}
@@ -65,7 +67,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
* :: DeveloperApi ::
* Sample the dataset.
* @param lowerBound Lower-bound of the sampling probability (usually 0.0)
- * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
* will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
@@ -83,7 +85,7 @@ case class Sample(
override def output: Seq[Attribute] = child.output
// TODO: How to pick seed?
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
} else {
@@ -99,7 +101,8 @@ case class Sample(
case class Union(children: Seq[SparkPlan]) extends SparkPlan {
// TODO: attributes output by union should be distinct for nullability purposes
override def output: Seq[Attribute] = children.head.output
- protected override def doExecute(): RDD[Row] = sparkContext.union(children.map(_.execute()))
+ protected override def doExecute(): RDD[InternalRow] =
+ sparkContext.union(children.map(_.execute()))
}
/**
@@ -124,19 +127,19 @@ case class Limit(limit: Int, child: SparkPlan)
override def executeCollect(): Array[Row] = child.executeTake(limit)
- protected override def doExecute(): RDD[Row] = {
- val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) {
+ protected override def doExecute(): RDD[InternalRow] = {
+ val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) {
child.execute().mapPartitions { iter =>
iter.take(limit).map(row => (false, row.copy()))
}
} else {
child.execute().mapPartitions { iter =>
- val mutablePair = new MutablePair[Boolean, Row]()
+ val mutablePair = new MutablePair[Boolean, InternalRow]()
iter.take(limit).map(row => mutablePair.update(false, row))
}
}
val part = new HashPartitioner(1)
- val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
+ val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
shuffled.mapPartitions(_.take(limit).map(_._2))
}
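
Limit takes `limit` rows inside each partition, shuffles the survivors to a single partition, and takes `limit` again. A sketch of that two-phase structure, with Seqs standing in for partitions and the shuffle:

```scala
// Per-partition take, then a final take after collapsing to one partition.
object TwoPhaseLimitSketch {
  def limit[A](partitions: Seq[Seq[A]], n: Int): Seq[A] = {
    val perPartition = partitions.map(_.take(n)) // map side
    perPartition.flatten.take(n)                 // after the single-partition shuffle
  }

  def main(args: Array[String]): Unit = {
    println(limit(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), n = 4)) // List(1, 2, 3, 4)
  }
}
```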
@@ -157,7 +160,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
- private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+ private def collectData(): Array[InternalRow] =
+ child.execute().map(_.copy()).takeOrdered(limit)(ord)
override def executeCollect(): Array[Row] = {
val converter = CatalystTypeConverters.createToScalaConverter(schema)
@@ -166,7 +170,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
- protected override def doExecute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)
+ protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
override def outputOrdering: Seq[SortOrder] = sortOrder
}
@@ -186,7 +190,7 @@ case class Sort(
override def requiredChildDistribution: Seq[Distribution] =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
- protected override def doExecute(): RDD[Row] = attachTree(this, "sort") {
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
child.execute().mapPartitions( { iterator =>
val ordering = newOrdering(sortOrder, child.output)
iterator.map(_.copy()).toArray.sorted(ordering).iterator
@@ -214,14 +218,14 @@ case class ExternalSort(
override def requiredChildDistribution: Seq[Distribution] =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
- protected override def doExecute(): RDD[Row] = attachTree(this, "sort") {
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
child.execute().mapPartitions( { iterator =>
val ordering = newOrdering(sortOrder, child.output)
- val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering))
+ val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering))
sorter.insertAll(iterator.map(r => (r.copy, null)))
val baseIterator = sorter.iterator.map(_._1)
// TODO(marmbrus): The complex type signature below thwarts inference for no reason.
- CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop())
+ CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
}, preservesPartitioning = true)
}
@@ -230,37 +234,6 @@ case class ExternalSort(
override def outputOrdering: Seq[SortOrder] = sortOrder
}
-/**
- * :: DeveloperApi ::
- * Computes the set of distinct input rows using a HashSet.
- * @param partial when true the distinct operation is performed partially, per partition, without
- * shuffling the data.
- * @param child the input query plan.
- */
-@DeveloperApi
-case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode {
- override def output: Seq[Attribute] = child.output
-
- override def requiredChildDistribution: Seq[Distribution] =
- if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil
-
- protected override def doExecute(): RDD[Row] = {
- child.execute().mapPartitions { iter =>
- val hashSet = new scala.collection.mutable.HashSet[Row]()
-
- var currentRow: Row = null
- while (iter.hasNext) {
- currentRow = iter.next()
- if (!hashSet.contains(currentRow)) {
- hashSet.add(currentRow.copy())
- }
- }
-
- hashSet.iterator
- }
- }
-}
-
/**
* :: DeveloperApi ::
* Return a new RDD that has exactly `numPartitions` partitions.
@@ -270,7 +243,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan)
extends UnaryNode {
override def output: Seq[Attribute] = child.output
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
child.execute().map(_.copy()).coalesce(numPartitions, shuffle)
}
}
@@ -285,7 +258,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan)
case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
left.execute().map(_.copy()).subtract(right.execute().map(_.copy()))
}
}
@@ -299,7 +272,7 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode {
case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode {
override def output: Seq[Attribute] = children.head.output
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
left.execute().map(_.copy()).intersection(right.execute().map(_.copy()))
}
}
@@ -314,5 +287,5 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode {
case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan {
def children: Seq[SparkPlan] = child :: Nil
- protected override def doExecute(): RDD[Row] = child.execute()
+ protected override def doExecute(): RDD[InternalRow] = child.execute()
}
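Repartition, Except and Intersect above map directly onto RDD primitives (coalesce, subtract and intersection, each on copied rows). A standalone sketch of those primitives (local master and sample data are illustrative):

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("set-ops"))
    val left = sc.parallelize(Seq(1, 2, 2, 3))
    val right = sc.parallelize(Seq(2, 4))

    left.coalesce(1, shuffle = true).collect()  // Repartition: same rows, one partition
    left.subtract(right).collect()              // Except: Array(1, 3)
    left.intersection(right).collect()          // Intersect: Array(2), de-duplicated
    sc.stop()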
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 49b361e96b2d..653792ea2e53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext}
/**
* A logical command that is executed for its side-effects. `RunnableCommand`s are
@@ -64,9 +64,9 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan
override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
val converted = sideEffectResult.map(r =>
- CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row])
+ CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[InternalRow])
sqlContext.sparkContext.parallelize(converted, 1)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index dffb265601bd..3ee4033baee2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -18,13 +18,15 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.unsafe.types.UTF8String
import scala.collection.mutable.HashSet
import org.apache.spark.{AccumulatorParam, Accumulator}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.{SQLConf, SQLContext, DataFrame, Row}
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.types._
@@ -125,11 +127,11 @@ package object debug {
}
}
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitions { iter =>
- new Iterator[Row] {
+ new Iterator[InternalRow] {
def hasNext: Boolean = iter.hasNext
- def next(): Row = {
+ def next(): InternalRow = {
val currentRow = iter.next()
tupleCount += 1
var i = 0
@@ -154,7 +156,7 @@ package object debug {
def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match {
case (null, _) =>
- case (row: Row, StructType(fields)) =>
+ case (row: InternalRow, StructType(fields)) =>
row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
case (s: Seq[_], ArrayType(elemType, _)) =>
s.foreach(typeCheck(_, elemType))
@@ -170,6 +172,8 @@ package object debug {
case (_: Short, ShortType) =>
case (_: Boolean, BooleanType) =>
case (_: Double, DoubleType) =>
+ case (_: Int, DateType) =>
+ case (_: Long, TimestampType) =>
case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType)
case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t")
@@ -193,7 +197,7 @@ package object debug {
def children: List[SparkPlan] = child :: Nil
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
child.execute().map { row =>
try typeCheck(row, child.schema) catch {
case e: Exception =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
index e228a60c9029..3b217348b7b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.expressions
import org.apache.spark.TaskContext
-import org.apache.spark.sql.catalyst.expressions.{Row, LeafExpression}
+import org.apache.spark.sql.catalyst.expressions.{InternalRow, LeafExpression}
import org.apache.spark.sql.types.{LongType, DataType}
/**
@@ -43,9 +43,11 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression {
override def dataType: DataType = LongType
- override def eval(input: Row): Long = {
+ override def eval(input: InternalRow): Long = {
val currentCount = count
count += 1
(TaskContext.get().partitionId().toLong << 33) + currentCount
}
+
+ override def isThreadSafe: Boolean = false
}
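The generated ID packs the partition id into the upper 31 bits and a per-partition counter into the lower 33 bits, so values increase within a partition but are not consecutive across partitions. A standalone illustration of the arithmetic used in eval above:

    // Same formula as eval: (partitionId.toLong << 33) + countWithinPartition
    def monotonicId(partitionId: Int, countWithinPartition: Long): Long =
      (partitionId.toLong << 33) + countWithinPartition

    monotonicId(0, 0L)  // 0
    monotonicId(0, 1L)  // 1
    monotonicId(1, 0L)  // 8589934592, i.e. 1L << 33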
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
index 1272793f88cd..12c2eed0d6b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.expressions
import org.apache.spark.TaskContext
-import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Row}
+import org.apache.spark.sql.catalyst.expressions.{LeafExpression, InternalRow}
import org.apache.spark.sql.types.{IntegerType, DataType}
@@ -31,5 +31,5 @@ private[sql] case object SparkPartitionID extends LeafExpression {
override def dataType: DataType = IntegerType
- override def eval(input: Row): Int = TaskContext.get().partitionId()
+ override def eval(input: InternalRow): Int = TaskContext.get().partitionId()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index b8b12be8756f..2d2e1b92b86b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -17,16 +17,15 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.rdd.RDD
-import org.apache.spark.util.ThreadUtils
-
import scala.concurrent._
import scala.concurrent.duration._
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.{Expression, InternalRow}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.ThreadUtils
/**
* :: DeveloperApi ::
@@ -61,12 +60,12 @@ case class BroadcastHashJoin(
@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
- val input: Array[Row] = buildPlan.execute().map(_.copy()).collect()
+ val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
sparkContext.broadcast(hashed)
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
val broadcastRelation = Await.result(broadcastFuture, timeout)
streamedPlan.execute().mapPartitions { streamedIter =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index a32e5fc4f7ea..412a3d4178e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
/**
@@ -38,10 +38,10 @@ case class BroadcastLeftSemiJoinHash(
override def output: Seq[Attribute] = left.output
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator
- val hashSet = new java.util.HashSet[Row]()
- var currentRow: Row = null
+ val hashSet = new java.util.HashSet[InternalRow]()
+ var currentRow: InternalRow = null
// Create a Hash set of buildKeys
while (buildIter.hasNext) {
@@ -50,7 +50,8 @@ case class BroadcastLeftSemiJoinHash(
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
- hashSet.add(rowKey)
+ // rowKey may not be serializable (from codegen)
+ hashSet.add(rowKey.copy())
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index caad3dfbe1c5..0b2cf8e12a6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -61,13 +61,14 @@ case class BroadcastNestedLoopJoin(
@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+ sparkContext.broadcast(broadcast.execute().map(_.copy())
+ .collect().toIndexedSeq)
/** All rows that either match both-way, or rows from streamed joined with nulls. */
val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
- val matchedRows = new CompactBuffer[Row]
+ val matchedRows = new CompactBuffer[InternalRow]
// TODO: Use Spark's BitSet.
val includedBroadcastTuples =
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
@@ -118,8 +119,8 @@ case class BroadcastNestedLoopJoin(
val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)
/** Rows from broadcasted joined with nulls. */
- val broadcastRowsWithNulls: Seq[Row] = {
- val buf: CompactBuffer[Row] = new CompactBuffer()
+ val broadcastRowsWithNulls: Seq[InternalRow] = {
+ val buf: CompactBuffer[InternalRow] = new CompactBuffer()
var i = 0
val rel = broadcastedRelation.value
while (i < rel.length) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index 191c00cb55da..261b4724159f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output ++ right.output
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
val leftResults = left.execute().map(_.copy())
val rightResults = right.execute().map(_.copy())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 851de1685509..3a4196a90d14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -49,11 +49,13 @@ trait HashJoin {
@transient protected lazy val streamSideKeyGenerator: () => MutableProjection =
newMutableProjection(streamedKeys, streamedPlan.output)
- protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] =
+ protected def hashJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] =
{
- new Iterator[Row] {
- private[this] var currentStreamedRow: Row = _
- private[this] var currentHashMatches: CompactBuffer[Row] = _
+ new Iterator[InternalRow] {
+ private[this] var currentStreamedRow: InternalRow = _
+ private[this] var currentHashMatches: CompactBuffer[InternalRow] = _
private[this] var currentMatchPosition: Int = -1
// Mutable per row objects.
@@ -65,7 +67,7 @@ trait HashJoin {
(currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
(streamIter.hasNext && fetchNext())
- override final def next(): Row = {
+ override final def next(): InternalRow = {
val ret = buildSide match {
case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
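The hashJoin iterator above streams one side and probes the build-side hash table row by row, buffering only the matches for the current streamed row. A minimal sketch of the probe phase using plain Scala collections (generic key and row types; names are illustrative):

    // For each streamed row, look up all build-side matches and emit the joined pairs.
    def probe[K, L, R](streamed: Iterator[L], buildTable: Map[K, Seq[R]])(key: L => K): Iterator[(L, R)] =
      streamed.flatMap { left =>
        buildTable.getOrElse(key(left), Seq.empty).iterator.map(right => (left, right))
      }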
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 45574392996c..bce0e8d70a57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -48,7 +48,8 @@ case class HashOuterJoin(
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
- case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
+ case x =>
+ throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
}
override def requiredChildDistribution: Seq[ClusteredDistribution] =
@@ -63,24 +64,27 @@ case class HashOuterJoin(
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
case x =>
- throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
+ throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
}
}
- @transient private[this] lazy val DUMMY_LIST = Seq[Row](null)
- @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row]
+ @transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null)
+ @transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow]
@transient private[this] lazy val leftNullRow = new GenericRow(left.output.length)
@transient private[this] lazy val rightNullRow = new GenericRow(right.output.length)
@transient private[this] lazy val boundCondition =
- condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
+ condition.map(
+ newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true)
// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
// iterator for performance purpose.
private[this] def leftOuterIterator(
- key: Row, joinedRow: JoinedRow, rightIter: Iterable[Row]): Iterator[Row] = {
- val ret: Iterable[Row] = {
+ key: InternalRow,
+ joinedRow: JoinedRow,
+ rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
+ val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
val temp = rightIter.collect {
case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
@@ -98,12 +102,15 @@ case class HashOuterJoin(
}
private[this] def rightOuterIterator(
- key: Row, leftIter: Iterable[Row], joinedRow: JoinedRow): Iterator[Row] = {
+ key: InternalRow,
+ leftIter: Iterable[InternalRow],
+ joinedRow: JoinedRow): Iterator[InternalRow] = {
- val ret: Iterable[Row] = {
+ val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
val temp = leftIter.collect {
- case l if boundCondition(joinedRow.withLeft(l)) => joinedRow.copy
+ case l if boundCondition(joinedRow.withLeft(l)) =>
+ joinedRow.copy
}
if (temp.size == 0) {
joinedRow.withLeft(leftNullRow).copy :: Nil
@@ -118,14 +125,14 @@ case class HashOuterJoin(
}
private[this] def fullOuterIterator(
- key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row],
- joinedRow: JoinedRow): Iterator[Row] = {
+ key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow],
+ joinedRow: JoinedRow): Iterator[InternalRow] = {
if (!key.anyNull) {
// Store the positions of records in right, if one of its associated row satisfy
// the join condition.
val rightMatchedSet = scala.collection.mutable.Set[Int]()
- leftIter.iterator.flatMap[Row] { l =>
+ leftIter.iterator.flatMap[InternalRow] { l =>
joinedRow.withLeft(l)
var matched = false
rightIter.zipWithIndex.collect {
@@ -156,24 +163,25 @@ case class HashOuterJoin(
joinedRow(leftNullRow, r).copy()
}
} else {
- leftIter.iterator.map[Row] { l =>
+ leftIter.iterator.map[InternalRow] { l =>
joinedRow(l, rightNullRow).copy()
- } ++ rightIter.iterator.map[Row] { r =>
+ } ++ rightIter.iterator.map[InternalRow] { r =>
joinedRow(leftNullRow, r).copy()
}
}
}
private[this] def buildHashTable(
- iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = {
- val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]()
+ iter: Iterator[InternalRow],
+ keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = {
+ val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]()
while (iter.hasNext) {
val currentRow = iter.next()
val rowKey = keyGenerator(currentRow)
var existingMatchList = hashTable.get(rowKey)
if (existingMatchList == null) {
- existingMatchList = new CompactBuffer[Row]()
+ existingMatchList = new CompactBuffer[InternalRow]()
hashTable.put(rowKey, existingMatchList)
}
@@ -183,7 +191,7 @@ case class HashOuterJoin(
hashTable
}
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
val joinedRow = new JoinedRow()
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
// TODO this probably can be replaced by external sort (sort merged join?)
@@ -216,7 +224,8 @@ case class HashOuterJoin(
rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow)
}
- case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
+ case x =>
+ throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index ab84c123e0c0..e18c81797513 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import java.io.{ObjectInput, ObjectOutput, Externalizable}
import java.util.{HashMap => JavaHashMap}
-import org.apache.spark.sql.catalyst.expressions.{Projection, Row}
+import org.apache.spark.sql.catalyst.expressions.{Projection, InternalRow}
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.util.collection.CompactBuffer
@@ -30,7 +30,7 @@ import org.apache.spark.util.collection.CompactBuffer
* object.
*/
private[joins] sealed trait HashedRelation {
- def get(key: Row): CompactBuffer[Row]
+ def get(key: InternalRow): CompactBuffer[InternalRow]
// This is a helper method to implement Externalizable, and is used by
// GeneralHashedRelation and UniqueKeyHashedRelation
@@ -54,12 +54,12 @@ private[joins] sealed trait HashedRelation {
* A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
*/
private[joins] final class GeneralHashedRelation(
- private var hashTable: JavaHashMap[Row, CompactBuffer[Row]])
+ private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
extends HashedRelation with Externalizable {
def this() = this(null) // Needed for serialization
- override def get(key: Row): CompactBuffer[Row] = hashTable.get(key)
+ override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key)
override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@@ -75,17 +75,18 @@ private[joins] final class GeneralHashedRelation(
* A specialized [[HashedRelation]] that maps key into a single value. This implementation
* assumes the key is unique.
*/
-private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[Row, Row])
+private[joins]
+final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
extends HashedRelation with Externalizable {
def this() = this(null) // Needed for serialization
- override def get(key: Row): CompactBuffer[Row] = {
+ override def get(key: InternalRow): CompactBuffer[InternalRow] = {
val v = hashTable.get(key)
if (v eq null) null else CompactBuffer(v)
}
- def getValue(key: Row): Row = hashTable.get(key)
+ def getValue(key: InternalRow): InternalRow = hashTable.get(key)
override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@@ -103,13 +104,13 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHa
private[joins] object HashedRelation {
def apply(
- input: Iterator[Row],
+ input: Iterator[InternalRow],
keyGenerator: Projection,
sizeEstimate: Int = 64): HashedRelation = {
// TODO: Use Spark's HashMap implementation.
- val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate)
- var currentRow: Row = null
+ val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
+ var currentRow: InternalRow = null
// Whether the join key is unique. If the key is unique, we can convert the underlying
// hash map into one specialized for this.
@@ -122,7 +123,7 @@ private[joins] object HashedRelation {
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[Row]()
+ val newMatchList = new CompactBuffer[InternalRow]()
hashTable.put(rowKey, newMatchList)
newMatchList
} else {
@@ -134,7 +135,7 @@ private[joins] object HashedRelation {
}
if (keyIsUnique) {
- val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size)
+ val uniqHashTable = new JavaHashMap[InternalRow, InternalRow](hashTable.size)
val iter = hashTable.entrySet().iterator()
while (iter.hasNext) {
val entry = iter.next()
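HashedRelation materializes the build side into a hash table keyed by the generated join key, switching to a single-value map when every key turns out to be unique. A compact sketch of the multimap build step with plain Scala collections (no CompactBuffer or Projection; names are illustrative):

    import scala.collection.mutable

    // Group build-side rows by join key; the unique-key case degenerates to one row per key.
    def buildHashTable[K, R](rows: Iterator[R])(key: R => K): mutable.Map[K, mutable.ArrayBuffer[R]] = {
      val table = mutable.Map.empty[K, mutable.ArrayBuffer[R]]
      rows.foreach { row =>
        table.getOrElseUpdate(key(row), mutable.ArrayBuffer.empty[R]) += row
      }
      table
    }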
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index 036423e6faea..2a6d4d1ab08b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -47,7 +47,7 @@ case class LeftSemiJoinBNL(
@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
val broadcastedRelation =
sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 8ad27eae80ff..20d74270afb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow}
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -42,10 +42,10 @@ case class LeftSemiJoinHash(
override def output: Seq[Attribute] = left.output
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- val hashSet = new java.util.HashSet[Row]()
- var currentRow: Row = null
+ val hashSet = new java.util.HashSet[InternalRow]()
+ var currentRow: InternalRow = null
// Create a Hash set of buildKeys
while (buildIter.hasNext) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 219525d9d85f..5439e10a60b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -43,7 +43,7 @@ case class ShuffledHashJoin(
override def requiredChildDistribution: Seq[ClusteredDistribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
hashJoin(streamIter, hashed)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 1a39fb4b9660..2abe65a71813 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -21,9 +21,7 @@ import java.util.NoSuchElementException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.collection.CompactBuffer
@@ -60,29 +58,29 @@ case class SortMergeJoin(
private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
keys.map(SortOrder(_, Ascending))
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
val leftResults = left.execute().map(_.copy())
val rightResults = right.execute().map(_.copy())
leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
- new Iterator[Row] {
+ new Iterator[InternalRow] {
// Mutable per row objects.
private[this] val joinRow = new JoinedRow5
- private[this] var leftElement: Row = _
- private[this] var rightElement: Row = _
- private[this] var leftKey: Row = _
- private[this] var rightKey: Row = _
- private[this] var rightMatches: CompactBuffer[Row] = _
+ private[this] var leftElement: InternalRow = _
+ private[this] var rightElement: InternalRow = _
+ private[this] var leftKey: InternalRow = _
+ private[this] var rightKey: InternalRow = _
+ private[this] var rightMatches: CompactBuffer[InternalRow] = _
private[this] var rightPosition: Int = -1
private[this] var stop: Boolean = false
- private[this] var matchKey: Row = _
+ private[this] var matchKey: InternalRow = _
// initialize iterator
initialize()
override final def hasNext: Boolean = nextMatchingPair()
- override final def next(): Row = {
+ override final def next(): InternalRow = {
if (hasNext) {
// we are using the buffered right rows and run down left iterator
val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
@@ -145,7 +143,7 @@ case class SortMergeJoin(
fetchLeft()
}
}
- rightMatches = new CompactBuffer[Row]()
+ rightMatches = new CompactBuffer[InternalRow]()
if (stop) {
stop = false
// iterate the right side to buffer all rows that matches
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 55f3ff470901..1ce150ceaf5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -24,18 +24,19 @@ import scala.collection.JavaConverters._
import net.razorvine.pickle.{Pickler, Unpickler}
+import org.apache.spark.{Accumulator, Logging => SparkLogging}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
-import org.apache.spark.{Accumulator, Logging => SparkLogging}
+import org.apache.spark.unsafe.types.UTF8String
/**
* A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]].
@@ -56,8 +57,8 @@ private[spark] case class PythonUDF(
def nullable: Boolean = true
- override def eval(input: Row): Any = {
- sys.error("PythonUDFs can not be directly evaluated.")
+ override def eval(input: InternalRow): Any = {
+ throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.")
}
}
@@ -71,43 +72,49 @@ private[spark] case class PythonUDF(
private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Skip EvaluatePython nodes.
- case p: EvaluatePython => p
+ case plan: EvaluatePython => plan
- case l: LogicalPlan =>
+ case plan: LogicalPlan =>
// Extract any PythonUDFs from the current operator.
- val udfs = l.expressions.flatMap(_.collect { case udf: PythonUDF => udf})
+ val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
if (udfs.isEmpty) {
// If there aren't any, we are done.
- l
+ plan
} else {
// Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
// If there is more than one, we will add another evaluation operator in a subsequent pass.
- val udf = udfs.head
-
- var evaluation: EvaluatePython = null
-
- // Rewrite the child that has the input required for the UDF
- val newChildren = l.children.map { child =>
- // Check to make sure that the UDF can be evaluated with only the input of this child.
- // Other cases are disallowed as they are ambiguous or would require a cartisian product.
- if (udf.references.subsetOf(child.outputSet)) {
- evaluation = EvaluatePython(udf, child)
- evaluation
- } else if (udf.references.intersect(child.outputSet).nonEmpty) {
- sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
- } else {
- child
- }
+ udfs.find(_.resolved) match {
+ case Some(udf) =>
+ var evaluation: EvaluatePython = null
+
+ // Rewrite the child that has the input required for the UDF
+ val newChildren = plan.children.map { child =>
+ // Check to make sure that the UDF can be evaluated with only the input of this child.
+ // Other cases are disallowed as they are ambiguous or would require a cartesian
+ // product.
+ if (udf.references.subsetOf(child.outputSet)) {
+ evaluation = EvaluatePython(udf, child)
+ evaluation
+ } else if (udf.references.intersect(child.outputSet).nonEmpty) {
+ sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+ } else {
+ child
+ }
+ }
+
+ assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
+
+ // Trim away the new UDF value if it was only used for filtering or something.
+ logical.Project(
+ plan.output,
+ plan.transformExpressions {
+ case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
+ }.withNewChildren(newChildren))
+
+ case None =>
+ // If there is no Python UDF that is resolved, skip this round.
+ plan
}
-
- assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
-
- // Trim away the new UDF value if it was only used for filtering or something.
- logical.Project(
- l.output,
- l.transformExpressions {
- case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
- }.withNewChildren(newChildren))
}
}
}
@@ -142,6 +149,7 @@ object EvaluatePython {
case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
case (date: Int, DateType) => DateUtils.toJavaDate(date)
+ case (t: Long, TimestampType) => DateUtils.toJavaTimestamp(t)
case (s: UTF8String, StringType) => s.toString
// Pyrolite can handle Timestamp and Decimal
@@ -180,10 +188,12 @@ object EvaluatePython {
}): Row
case (c: java.util.Calendar, DateType) =>
- DateUtils.fromJavaDate(new java.sql.Date(c.getTime().getTime()))
+ DateUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis))
case (c: java.util.Calendar, TimestampType) =>
- new java.sql.Timestamp(c.getTime().getTime())
+ c.getTimeInMillis * 10000L
+ case (t: java.sql.Timestamp, TimestampType) =>
+ DateUtils.fromJavaTimestamp(t)
case (_, udt: UserDefinedType[_]) =>
fromJava(obj, udt.sqlType)
@@ -195,8 +205,10 @@ object EvaluatePython {
case (c: Long, IntegerType) => c.toInt
case (c: Int, LongType) => c.toLong
case (c: Double, FloatType) => c.toFloat
- case (c: String, StringType) => UTF8String(c)
- case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString)
+ case (c: String, StringType) => UTF8String.fromString(c)
+ case (c, StringType) =>
+ // If we get here, c is not a string. Call toString on it.
+ UTF8String.fromString(c.toString)
case (c, _) => c
}
@@ -230,7 +242,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
def children: Seq[SparkPlan] = child :: Nil
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
val childResults = child.execute().map(_.copy())
val parent = childResults.mapPartitions { iter =>
@@ -265,7 +277,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val row = new GenericMutableRow(1)
iter.map { result =>
row(0) = EvaluatePython.fromJava(result, udf.dataType)
- row: Row
+ row: InternalRow
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index fe8a81e3d043..8df1da037c43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.stat
import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.Logging
-import org.apache.spark.sql.{Column, DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}
+import org.apache.spark.sql.{Column, DataFrame}
private[sql] object FrequentItems extends Logging {
@@ -62,7 +63,7 @@ private[sql] object FrequentItems extends Logging {
}
/**
- * Finding frequent items for columns, possibly with false positives. Using the
+ * Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* The `support` should be greater than 1e-4.
@@ -75,7 +76,7 @@ private[sql] object FrequentItems extends Logging {
* @return A Local DataFrame with the Array of frequent items for each column.
*/
private[sql] def singlePassFreqItems(
- df: DataFrame,
+ df: DataFrame,
cols: Seq[String],
support: Double): DataFrame = {
require(support >= 1e-4, s"support ($support) must be greater than 1e-4.")
@@ -88,7 +89,7 @@ private[sql] object FrequentItems extends Logging {
val index = originalSchema.fieldIndex(name)
(name, originalSchema.fields(index).dataType)
}
-
+
val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
seqOp = (counts, row) => {
var i = 0
@@ -110,7 +111,7 @@ private[sql] object FrequentItems extends Logging {
}
)
val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
- val resultRow = Row(justItems : _*)
+ val resultRow = InternalRow(justItems : _*)
// append frequent Items to the column name for easy debugging
val outputCols = colInfo.map { v =>
StructField(v._1 + "_freqItems", ArrayType(v._2, false))
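The scaladoc above refers to the frequent element count algorithm of Karp, Schenker and Papadimitriou: keep a bounded set of counters and decrement all of them whenever a new item arrives while the set is full. A generic sketch of that counting scheme (k plays the role of roughly 1 / support; this is not the counter class used in this file):

    import scala.collection.mutable

    // Items occurring more than n / k times always survive; false positives are possible,
    // false negatives are not.
    def frequentItems[T](items: Iterator[T], k: Int): Map[T, Long] = {
      val counts = mutable.Map.empty[T, Long]
      items.foreach { item =>
        if (counts.contains(item)) {
          counts(item) += 1
        } else if (counts.size < k) {
          counts(item) = 1L
        } else {
          // The map is full: decrement every counter and drop those that reach zero.
          counts.keys.toSeq.foreach { key =>
            counts(key) -= 1
            if (counts(key) == 0) counts.remove(key)
          }
        }
      }
      counts.toMap
    }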
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index d22f5fd2d439..93383e5a62f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -18,14 +18,14 @@
package org.apache.spark.sql.execution.stat
import org.apache.spark.Logging
-import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.{Row, Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
private[sql] object StatFunctions extends Logging {
-
+
/** Calculate the Pearson Correlation Coefficient for the given columns */
private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
val counts = collectStatisticalData(df, cols)
@@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging {
s"exceed 1e4. Currently $columnSize")
val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
val countsRow = new GenericMutableRow(columnSize + 1)
- rows.foreach { row =>
+ rows.foreach { (row: Row) =>
+ // row.get(0) is column 1
+ // row.get(1) is column 2
+ // row.get(2) is the frequency
countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
}
// the value of col1 is the first value, the rest are the counts
@@ -126,6 +129,6 @@ private[sql] object StatFunctions extends Logging {
val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq
val schema = StructType(StructField(tableName, StringType) +: headerNames)
- new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
+ new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
}
}
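With the trailing .na.fill(0.0), value pairs that never co-occur show up as 0 instead of null in the contingency table. A usage sketch (assuming a DataFrame df with columns "a" and "b"; crosstab is exposed through df.stat):

    // Rows are distinct values of "a", columns are distinct values of "b",
    // cells are co-occurrence counts with missing combinations filled as 0.
    val contingency = df.stat.crosstab("a", "b")
    contingency.show()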
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index d4003b2d9cbf..e9b60841fc28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -79,3 +79,20 @@ object Window {
}
}
+
+/**
+ * :: Experimental ::
+ * Utility functions for defining windows in DataFrames.
+ *
+ * {{{
+ * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+ * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0)
+ *
+ * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING
+ * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3)
+ * }}}
+ *
+ * @since 1.4.0
+ */
+@Experimental
+class Window private() // So we can see Window in JavaDoc.
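A usage sketch for the window specification documented above (assuming a DataFrame named sales with columns "country", "date" and "amount"):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // Running total per country, ordered by date, from the first row up to the current row.
    val w = Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0)
    val withTotal = sales.select(
      col("country"), col("date"),
      sum("amount").over(w).as("running_total"))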
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6dc17bbb2e76..c5b77724aae1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -24,7 +24,6 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.mathfuncs._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -38,6 +37,7 @@ import org.apache.spark.util.Utils
* @groupname normal_funcs Non-aggregate functions
* @groupname math_funcs Math functions
* @groupname window_funcs Window functions
+ * @groupname string_funcs String functions
* @groupname Ungrouped Support functions for DataFrames.
* @since 1.3.0
*/
@@ -945,6 +945,15 @@ object functions {
*/
def cosh(columnName: String): Column = cosh(Column(columnName))
+ /**
+ * Returns the double value that is closer than any other to e, the base of the natural
+ * logarithms.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def e(): Column = EulerNumber()
+
/**
* Computes the exponential of the given value.
*
@@ -1075,7 +1084,7 @@ object functions {
def log(columnName: String): Column = log(Column(columnName))
/**
- * Computes the logarithm of the given value in Base 10.
+ * Computes the logarithm of the given value in base 10.
*
* @group math_funcs
* @since 1.4.0
@@ -1083,7 +1092,7 @@ object functions {
def log10(e: Column): Column = Log10(e.expr)
/**
- * Computes the logarithm of the given value in Base 10.
+ * Computes the logarithm of the given value in base 10.
*
* @group math_funcs
* @since 1.4.0
@@ -1106,6 +1115,31 @@ object functions {
*/
def log1p(columnName: String): Column = log1p(Column(columnName))
+ /**
+ * Returns the double value that is closer than any other to pi, the ratio of the circumference
+ * of a circle to its diameter.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def pi(): Column = Pi()
+
+ /**
+ * Computes the logarithm of the given column in base 2.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def log2(expr: Column): Column = Log2(expr.expr)
+
+ /**
+ * Computes the logarithm of the given value in base 2.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def log2(columnName: String): Column = log2(Column(columnName))
+
/**
* Returns the value of the first argument raised to the power of the second argument.
*
@@ -1299,7 +1333,24 @@ object functions {
* @since 1.4.0
*/
def toRadians(columnName: String): Column = toRadians(Column(columnName))
-
+
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ // String functions
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Computes the length of a given string value.
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def strlen(e: Column): Column = StringLength(e.expr)
+
+ /**
+ * Computes the length of a given string column.
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def strlen(columnName: String): Column = strlen(Column(columnName))
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
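A usage sketch for the functions added above (assuming a DataFrame df with a string column "name" and a numeric column "value"):

    import org.apache.spark.sql.functions._

    df.select(
      strlen(col("name")),  // length of the string in each row
      log2("value"),        // base-2 logarithm
      pi(),                 // 3.141592653589793
      e())                  // 2.718281828459045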
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 0bdb68e8ac84..226b143923df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -24,10 +24,11 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow}
+import org.apache.spark.sql.catalyst.expressions.{InternalRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources._
+import org.apache.spark.unsafe.types.UTF8String
/**
* Data corresponding to one partition of a JDBCRDD.
@@ -54,7 +55,7 @@ private[sql] object JDBCRDD extends Logging {
val answer = sqlType match {
// scalastyle:off
case java.sql.Types.ARRAY => null
- case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType.Unlimited }
+ case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
case java.sql.Types.BINARY => BinaryType
case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
case java.sql.Types.BLOB => BinaryType
@@ -210,13 +211,15 @@ private[sql] object JDBCRDD extends Logging {
fqTable: String,
requiredColumns: Array[String],
filters: Array[Filter],
- parts: Array[Partition]): RDD[Row] = {
+ parts: Array[Partition]): RDD[InternalRow] = {
+ val dialect = JdbcDialects.get(url)
+ val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
new JDBCRDD(
sc,
getConnector(driver, url, properties),
pruneSchema(schema, requiredColumns),
fqTable,
- requiredColumns,
+ quotedColumns,
filters,
parts,
properties)
@@ -237,7 +240,7 @@ private[sql] class JDBCRDD(
filters: Array[Filter],
partitions: Array[Partition],
properties: Properties)
- extends RDD[Row](sc, Nil) {
+ extends RDD[InternalRow](sc, Nil) {
/**
* Retrieve the list of partitions corresponding to this RDD.
@@ -262,7 +265,7 @@ private[sql] class JDBCRDD(
}
private def escapeSql(value: String): String =
- if (value == null) null else StringUtils.replace(value, "'", "''")
+ if (value == null) null else StringUtils.replace(value, "'", "''")
/**
* Turns a single Filter into a String representing a SQL expression.
@@ -304,7 +307,7 @@ private[sql] class JDBCRDD(
// Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that
// we don't have to potentially poke around in the Metadata once for every
- // row.
+ // row.
// Is there a better way to do this? I'd rather be using a type that
// contains only the tags I define.
abstract class JDBCConversion
@@ -345,12 +348,12 @@ private[sql] class JDBCRDD(
/**
* Runs the SQL query against the JDBC driver.
*/
- override def compute(thePart: Partition, context: TaskContext): Iterator[Row] = new Iterator[Row]
- {
+ override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
+ new Iterator[InternalRow] {
var closed = false
var finished = false
var gotNext = false
- var nextValue: Row = null
+ var nextValue: InternalRow = null
context.addTaskCompletionListener{ context => close() }
val part = thePart.asInstanceOf[JDBCPartition]
@@ -372,7 +375,7 @@ private[sql] class JDBCRDD(
val conversions = getConversions(schema)
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
- def getNext(): Row = {
+ def getNext(): InternalRow = {
if (rs.next()) {
var i = 0
while (i < conversions.length) {
@@ -383,7 +386,7 @@ private[sql] class JDBCRDD(
// DateUtils.fromJavaDate does not handle null value, so we need to check it.
val dateVal = rs.getDate(pos)
if (dateVal != null) {
- mutableRow.update(i, DateUtils.fromJavaDate(dateVal))
+ mutableRow.setInt(i, DateUtils.fromJavaDate(dateVal))
} else {
mutableRow.update(i, null)
}
@@ -415,7 +418,13 @@ private[sql] class JDBCRDD(
case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
// TODO(davies): use getBytes for better performance, if the encoding is UTF-8
case StringConversion => mutableRow.setString(i, rs.getString(pos))
- case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos))
+ case TimestampConversion =>
+ val t = rs.getTimestamp(pos)
+ if (t != null) {
+ mutableRow.setLong(i, DateUtils.fromJavaTimestamp(t))
+ } else {
+ mutableRow.update(i, null)
+ }
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
case BinaryLongConversion => {
val bytes = rs.getBytes(pos)
@@ -434,7 +443,7 @@ private[sql] class JDBCRDD(
mutableRow
} else {
finished = true
- null.asInstanceOf[Row]
+ null.asInstanceOf[InternalRow]
}
}
@@ -477,7 +486,7 @@ private[sql] class JDBCRDD(
!finished
}
- override def next(): Row = {
+ override def next(): InternalRow = {
if (!hasNext) {
throw new NoSuchElementException("End of stream")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 09d6865457df..4d3aac464c53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -23,10 +23,9 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
/**
* Instructions on how to partition the table among workers.
@@ -54,7 +53,7 @@ private[sql] object JDBCRelation {
if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0))
// Overflow and silliness can happen if you subtract then divide.
// Here we get a little roundoff, but that's (hopefully) OK.
- val stride: Long = (partitioning.upperBound / numPartitions
+ val stride: Long = (partitioning.upperBound / numPartitions
- partitioning.lowerBound / numPartitions)
var i: Int = 0
var currentValue: Long = partitioning.lowerBound
@@ -138,12 +137,12 @@ private[sql] case class JDBCRelation(
table,
requiredColumns,
filters,
- parts)
+ parts).map(_.asInstanceOf[Row])
}
-
+
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
.jdbc(url, table, properties)
- }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 6a169e106b96..8849fc2f1f0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.jdbc
+import java.sql.Types
+
import org.apache.spark.sql.types._
import org.apache.spark.annotation.DeveloperApi
-import java.sql.Types
-
/**
* :: DeveloperApi ::
* A database type definition coupled with the jdbc type needed to send null
@@ -80,6 +80,14 @@ abstract class JdbcDialect {
* @return The new JdbcType if there is an override for this DataType
*/
def getJDBCType(dt: DataType): Option[JdbcType] = None
+
+ /**
+ * Quotes the identifier. This is used to put quotes around the identifier in case the column
+ * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space).
+ */
+ def quoteIdentifier(colName: String): String = {
+ s""""$colName""""
+ }
}
/**
@@ -141,18 +149,19 @@ object JdbcDialects {
@DeveloperApi
class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect {
- require(!dialects.isEmpty)
+ require(dialects.nonEmpty)
- def canHandle(url : String): Boolean =
+ override def canHandle(url : String): Boolean =
dialects.map(_.canHandle(url)).reduce(_ && _)
override def getCatalystType(
- sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
- dialects.map(_.getCatalystType(sqlType, typeName, size, md)).flatten.headOption
-
- override def getJDBCType(dt: DataType): Option[JdbcType] =
- dialects.map(_.getJDBCType(dt)).flatten.headOption
+ sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+ dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption
+ }
+ override def getJDBCType(dt: DataType): Option[JdbcType] = {
+ dialects.flatMap(_.getJDBCType(dt)).headOption
+ }
}
/**
@@ -161,7 +170,7 @@ class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect {
*/
@DeveloperApi
case object NoopDialect extends JdbcDialect {
- def canHandle(url : String): Boolean = true
+ override def canHandle(url : String): Boolean = true
}
/**
@@ -170,7 +179,7 @@ case object NoopDialect extends JdbcDialect {
*/
@DeveloperApi
case object PostgresDialect extends JdbcDialect {
- def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
+ override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
@@ -196,7 +205,7 @@ case object PostgresDialect extends JdbcDialect {
*/
@DeveloperApi
case object MySQLDialect extends JdbcDialect {
- def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
+ override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
@@ -208,4 +217,8 @@ case object MySQLDialect extends JdbcDialect {
Some(BooleanType)
} else None
}
+
+ override def quoteIdentifier(colName: String): String = {
+ s"`$colName`"
+ }
}
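
Editor's note: with `quoteIdentifier` added as an overridable hook and the built-in dialects switched to explicit `override`s, a custom dialect only has to implement this small surface. A hedged sketch follows; `HypotheticalDialect`, the `jdbc:hypo` URL prefix, and the `JSONISH` type name are invented for illustration.

import java.sql.Types

import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types._

case object HypotheticalDialect extends JdbcDialect {
  // Claim URLs by prefix, the same shape as PostgresDialect and MySQLDialect above.
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:hypo")

  // Quote with backticks, as MySQLDialect now does, instead of the default double quotes.
  override def quoteIdentifier(colName: String): String = s"`$colName`"

  // Map an engine-specific type to a Catalyst type, mirroring getCatalystType above.
  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.OTHER && typeName == "JSONISH") Some(StringType) else None
  }
}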
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index f21dd29aca37..dd8aaf647489 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -240,10 +240,10 @@ package object jdbc {
}
}
}
-
+
def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
- case driver => driver.getClass.getCanonicalName
+ case driver => driver.getClass.getCanonicalName
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index 06aa19ef09bd..565d10247f10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -147,7 +147,7 @@ private[sql] object InferSchema {
* Returns the most general data type for two given data types.
*/
private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
- HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+ HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
case (other: DataType, NullType) => other
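
Editor's note: the rename to `findTightestCommonTypeOfTwo` keeps the two-step pattern used by `compatibleType`: try a tight common type first, then fall back to coarser structural rules. The sketch below illustrates that pattern only; the widening order and the final StringType fallback are assumptions for illustration, not Spark's actual coercion table.

import org.apache.spark.sql.types._

object CompatibleTypeSketch {
  private val numericOrder: Seq[DataType] =
    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)

  // A toy "tightest common type of two": identical types, or the wider numeric type.
  def tightestCommonTypeOfTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
    case (a, b) if a == b => Some(a)
    case (a, b) if numericOrder.contains(a) && numericOrder.contains(b) =>
      Some(numericOrder(math.max(numericOrder.indexOf(a), numericOrder.indexOf(b))))
    case _ => None
  }

  // Fall back to structural rules when no tight common type exists.
  def compatibleType(t1: DataType, t2: DataType): DataType =
    tightestCommonTypeOfTwo(t1, t2).getOrElse {
      (t1, t2) match {
        case (other: DataType, NullType) => other
        case (NullType, other: DataType) => other
        case _ => StringType // illustrative catch-all
      }
    }

  def main(args: Array[String]): Unit = {
    println(compatibleType(IntegerType, DoubleType)) // DoubleType
    println(compatibleType(IntegerType, StringType)) // StringType
  }
}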
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index c772cd1f53e5..69bf13e1e5a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -22,10 +22,10 @@ import java.io.IOException
import org.apache.hadoop.fs.Path
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute, Row}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{StructField, StructType}
-import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
private[sql] class DefaultSource
@@ -154,12 +154,12 @@ private[sql] class JSONRelation(
JacksonParser(
baseRDD(),
schema,
- sqlContext.conf.columnNameOfCorruptRecord)
+ sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row])
} else {
JsonRDD.jsonStringToRow(
baseRDD(),
schema,
- sqlContext.conf.columnNameOfCorruptRecord)
+ sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row])
}
}
@@ -168,12 +168,12 @@ private[sql] class JSONRelation(
JacksonParser(
baseRDD(),
StructType.fromAttributes(requiredColumns),
- sqlContext.conf.columnNameOfCorruptRecord)
+ sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row])
} else {
JsonRDD.jsonStringToRow(
baseRDD(),
StructType.fromAttributes(requiredColumns),
- sqlContext.conf.columnNameOfCorruptRecord)
+ sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row])
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala
index 325f54b6808a..1e6b1198d245 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala
@@ -21,7 +21,7 @@ import scala.collection.Map
import com.fasterxml.jackson.core._
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
private[sql] object JacksonGenerator {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 0e223758051a..817e8a20b34d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.json
import java.io.ByteArrayOutputStream
-import java.sql.Timestamp
import scala.collection.Map
@@ -29,12 +28,14 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
private[sql] object JacksonParser {
def apply(
json: RDD[String],
schema: StructType,
- columnNameOfCorruptRecords: String): RDD[Row] = {
+ columnNameOfCorruptRecords: String): RDD[InternalRow] = {
parseJson(json, schema, columnNameOfCorruptRecords)
}
@@ -55,7 +56,7 @@ private[sql] object JacksonParser {
convertField(factory, parser, schema)
case (VALUE_STRING, StringType) =>
- UTF8String(parser.getText)
+ UTF8String.fromString(parser.getText)
case (VALUE_STRING, _) if parser.getTextLength < 1 =>
// guard the non string type
@@ -65,17 +66,17 @@ private[sql] object JacksonParser {
DateUtils.millisToDays(DateUtils.stringToTime(parser.getText).getTime)
case (VALUE_STRING, TimestampType) =>
- new Timestamp(DateUtils.stringToTime(parser.getText).getTime)
+ DateUtils.stringToTime(parser.getText).getTime * 10000L
case (VALUE_NUMBER_INT, TimestampType) =>
- new Timestamp(parser.getLongValue)
+ parser.getLongValue * 10000L
case (_, StringType) =>
val writer = new ByteArrayOutputStream()
val generator = factory.createGenerator(writer, JsonEncoding.UTF8)
generator.copyCurrentStructure(parser)
generator.close()
- UTF8String(writer.toByteArray)
+ UTF8String.fromBytes(writer.toByteArray)
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) =>
parser.getFloatValue
@@ -129,7 +130,10 @@ private[sql] object JacksonParser {
*
* Fields in the json that are not defined in the requested schema will be dropped.
*/
- private def convertObject(factory: JsonFactory, parser: JsonParser, schema: StructType): Row = {
+ private def convertObject(
+ factory: JsonFactory,
+ parser: JsonParser,
+ schema: StructType): InternalRow = {
val row = new GenericMutableRow(schema.length)
while (nextUntil(parser, JsonToken.END_OBJECT)) {
schema.getFieldIndex(parser.getCurrentName) match {
@@ -153,7 +157,8 @@ private[sql] object JacksonParser {
valueType: DataType): Map[UTF8String, Any] = {
val builder = Map.newBuilder[UTF8String, Any]
while (nextUntil(parser, JsonToken.END_OBJECT)) {
- builder += UTF8String(parser.getCurrentName) -> convertField(factory, parser, valueType)
+ builder +=
+ UTF8String.fromString(parser.getCurrentName) -> convertField(factory, parser, valueType)
}
builder.result()
@@ -174,14 +179,14 @@ private[sql] object JacksonParser {
private def parseJson(
json: RDD[String],
schema: StructType,
- columnNameOfCorruptRecords: String): RDD[Row] = {
+ columnNameOfCorruptRecords: String): RDD[InternalRow] = {
- def failedRecord(record: String): Seq[Row] = {
+ def failedRecord(record: String): Seq[InternalRow] = {
// create a row even if no corrupt record column is present
val row = new GenericMutableRow(schema.length)
for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) {
require(schema(corruptIndex).dataType == StringType)
- row.update(corruptIndex, UTF8String(record))
+ row.update(corruptIndex, UTF8String.fromString(record))
}
Seq(row)
@@ -200,7 +205,7 @@ private[sql] object JacksonParser {
// convertField wrap an object into a single value array when necessary.
convertField(factory, parser, ArrayType(schema)) match {
case null => failedRecord(record)
- case list: Seq[Row @unchecked] => list
+ case list: Seq[InternalRow @unchecked] => list
case _ =>
sys.error(
s"Failed to parse record $record. Please make sure that each line of the file " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 95eb1174b1dd..44594c5080ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.json
-import java.sql.Timestamp
-
import scala.collection.Map
import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
@@ -32,13 +30,15 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
private[sql] object JsonRDD extends Logging {
private[sql] def jsonStringToRow(
json: RDD[String],
schema: StructType,
- columnNameOfCorruptRecords: String): RDD[Row] = {
+ columnNameOfCorruptRecords: String): RDD[InternalRow] = {
parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema))
}
@@ -155,7 +155,7 @@ private[sql] object JsonRDD extends Logging {
* Returns the most general data type for two given data types.
*/
private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
- HiveTypeCoercion.findTightestCommonType(t1, t2) match {
+ HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match {
case Some(commonType) => commonType
case None =>
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
@@ -319,7 +319,7 @@ private[sql] object JsonRDD extends Logging {
parsed
} catch {
case e: JsonProcessingException =>
- Map(columnNameOfCorruptRecords -> UTF8String(record)) :: Nil
+ Map(columnNameOfCorruptRecords -> UTF8String.fromString(record)) :: Nil
}
}
})
@@ -398,11 +398,11 @@ private[sql] object JsonRDD extends Logging {
}
}
- private def toTimestamp(value: Any): Timestamp = {
+ private def toTimestamp(value: Any): Long = {
value match {
- case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong)
- case value: java.lang.Long => new Timestamp(value)
- case value: java.lang.String => toTimestamp(DateUtils.stringToTime(value).getTime)
+ case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 10000L
+ case value: java.lang.Long => value * 10000L
+ case value: java.lang.String => DateUtils.stringToTime(value).getTime * 10000L
}
}
@@ -411,7 +411,7 @@ private[sql] object JsonRDD extends Logging {
null
} else {
desiredType match {
- case StringType => UTF8String(toString(value))
+ case StringType => UTF8String.fromString(toString(value))
case _ if value == null || value == "" => null // guard the non string type
case IntegerType => value.asInstanceOf[IntegerType.InternalType]
case LongType => toLong(value)
@@ -425,7 +425,7 @@ private[sql] object JsonRDD extends Logging {
val map = value.asInstanceOf[Map[String, Any]]
map.map {
case (k, v) =>
- (UTF8String(k), enforceCorrectType(v, valueType))
+ (UTF8String.fromString(k), enforceCorrectType(v, valueType))
}.map(identity)
case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
case DateType => toDate(value)
@@ -434,7 +434,7 @@ private[sql] object JsonRDD extends Logging {
}
}
- private def asRow(json: Map[String, Any], schema: StructType): Row = {
+ private def asRow(json: Map[String, Any], schema: StructType): InternalRow = {
// TODO: Reuse the row instead of creating a new one for every record.
val row = new GenericMutableRow(schema.fields.length)
schema.fields.zipWithIndex.foreach {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 3f97a11ceb97..4e94fd07a877 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -44,6 +44,7 @@ package object sql {
/**
* Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala.
+ * @deprecated As of 1.3.0, replaced by `DataFrame`.
*/
@deprecated("1.3.0", "use DataFrame")
type SchemaRDD = DataFrame
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala
index f5ce2718bec4..62c4e92ebec6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala
@@ -21,9 +21,9 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
-import parquet.Log
-import parquet.hadoop.util.ContextUtil
-import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat}
+import org.apache.parquet.Log
+import org.apache.parquet.hadoop.util.ContextUtil
+import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat}
private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
extends ParquetOutputCommitter(outputPath, context) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 1b4196ab0be3..4da5e96b82e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -23,14 +23,16 @@ import java.util.{TimeZone, Calendar}
import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap}
import jodd.datetime.JDateTime
-import parquet.column.Dictionary
-import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter}
-import parquet.schema.MessageType
+import org.apache.parquet.column.Dictionary
+import org.apache.parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter}
+import org.apache.parquet.schema.MessageType
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.parquet.CatalystConverter.FieldType
import org.apache.spark.sql.types._
import org.apache.spark.sql.parquet.timestamp.NanoTime
+import org.apache.spark.unsafe.types.UTF8String
/**
* Collection of converters of Parquet types (group and primitive types) that
@@ -77,7 +79,7 @@ private[sql] object CatalystConverter {
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
type ArrayScalaType[T] = Seq[T]
- type StructScalaType[T] = Row
+ type StructScalaType[T] = InternalRow
type MapScalaType[K, V] = Map[K, V]
protected[parquet] def createConverter(
@@ -221,7 +223,7 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
updateField(fieldIndex, value.getBytes)
protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit =
- updateField(fieldIndex, UTF8String(value))
+ updateField(fieldIndex, UTF8String.fromBytes(value))
protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit =
updateField(fieldIndex, readTimestamp(value))
@@ -238,13 +240,15 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
*
* @return
*/
- def getCurrentRecord: Row = throw new UnsupportedOperationException
+ def getCurrentRecord: InternalRow = throw new UnsupportedOperationException
/**
* Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in
* a long (i.e. precision <= 18)
+ *
+ * Returned value is needed by CatalystConverter, which doesn't reuse the Decimal object.
*/
- protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = {
+ protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = {
val precision = ctype.precisionInfo.get.precision
val scale = ctype.precisionInfo.get.scale
val bytes = value.getBytes
@@ -264,14 +268,14 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
/**
* Read a Timestamp value from a Parquet Int96Value
*/
- protected[parquet] def readTimestamp(value: Binary): Timestamp = {
- CatalystTimestampConverter.convertToTimestamp(value)
+ protected[parquet] def readTimestamp(value: Binary): Long = {
+ DateUtils.fromJavaTimestamp(CatalystTimestampConverter.convertToTimestamp(value))
}
}
/**
* A `parquet.io.api.GroupConverter` that is able to convert a Parquet record
- * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object.
+ * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object.
*
* @param schema The corresponding Catalyst schema in the form of a list of attributes.
*/
@@ -280,7 +284,7 @@ private[parquet] class CatalystGroupConverter(
protected[parquet] val index: Int,
protected[parquet] val parent: CatalystConverter,
protected[parquet] var current: ArrayBuffer[Any],
- protected[parquet] var buffer: ArrayBuffer[Row])
+ protected[parquet] var buffer: ArrayBuffer[InternalRow])
extends CatalystConverter {
def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) =
@@ -289,7 +293,7 @@ private[parquet] class CatalystGroupConverter(
index,
parent,
current = null,
- buffer = new ArrayBuffer[Row](
+ buffer = new ArrayBuffer[InternalRow](
CatalystArrayConverter.INITIAL_ARRAY_SIZE))
/**
@@ -305,7 +309,7 @@ private[parquet] class CatalystGroupConverter(
override val size = schema.size
- override def getCurrentRecord: Row = {
+ override def getCurrentRecord: InternalRow = {
assert(isRootConverter, "getCurrentRecord should only be called in root group converter!")
// TODO: use iterators if possible
// Note: this will ever only be called in the root converter when the record has been
@@ -343,7 +347,7 @@ private[parquet] class CatalystGroupConverter(
/**
* A `parquet.io.api.GroupConverter` that is able to convert a Parquet record
- * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. Note that his
+ * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. Note that this
* converter is optimized for rows of primitive types (non-nested records).
*/
private[parquet] class CatalystPrimitiveRowConverter(
@@ -369,7 +373,7 @@ private[parquet] class CatalystPrimitiveRowConverter(
override val parent = null
// Should be only called in root group converter!
- override def getCurrentRecord: Row = current
+ override def getCurrentRecord: InternalRow = current
override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
@@ -399,7 +403,7 @@ private[parquet] class CatalystPrimitiveRowConverter(
current.setInt(fieldIndex, value)
override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit =
- current.update(fieldIndex, value)
+ current.setInt(fieldIndex, value)
override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit =
current.setLong(fieldIndex, value)
@@ -420,10 +424,10 @@ private[parquet] class CatalystPrimitiveRowConverter(
current.update(fieldIndex, value.getBytes)
override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit =
- current.update(fieldIndex, UTF8String(value))
+ current.update(fieldIndex, UTF8String.fromBytes(value))
override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit =
- current.update(fieldIndex, readTimestamp(value))
+ current.setLong(fieldIndex, readTimestamp(value))
override protected[parquet] def updateDecimal(
fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = {
@@ -716,7 +720,7 @@ private[parquet] class CatalystNativeArrayConverter(
override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = {
checkGrowBuffer()
- buffer(elements) = UTF8String(value).asInstanceOf[NativeType]
+ buffer(elements) = UTF8String.fromBytes(value).asInstanceOf[NativeType]
elements += 1
}
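
Editor's note: `readDecimal` above relies on the fact that a decimal with precision <= 18 fits its unscaled value in a Long, so the Parquet binary is just a big-endian two's-complement integer plus a known scale. A standalone sketch of that decoding, using plain BigDecimal rather than Spark's internal Decimal class:

object DecimalBytesSketch {
  def decode(bytes: Array[Byte], scale: Int): BigDecimal = {
    require(bytes.nonEmpty && bytes.length <= 8, "unscaled value must fit in a Long")
    // Sign-extend from the first byte, then shift the remaining bytes in.
    var unscaled: Long = bytes(0).toLong
    var i = 1
    while (i < bytes.length) {
      unscaled = (unscaled << 8) | (bytes(i) & 0xff)
      i += 1
    }
    BigDecimal(BigInt(unscaled), scale)
  }

  def main(args: Array[String]): Unit = {
    // 0x3039 = 12345 unscaled, scale 2 -> 123.45
    println(decode(Array[Byte](0x30.toByte, 0x39.toByte), 2))
  }
}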
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
index f0f4e7d147e7..d57b789f5c1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -17,20 +17,23 @@
package org.apache.spark.sql.parquet
+import java.io.Serializable
import java.nio.ByteBuffer
import com.google.common.io.BaseEncoding
import org.apache.hadoop.conf.Configuration
-import parquet.filter2.compat.FilterCompat
-import parquet.filter2.compat.FilterCompat._
-import parquet.filter2.predicate.FilterApi._
-import parquet.filter2.predicate.{FilterApi, FilterPredicate}
-import parquet.io.api.Binary
+import org.apache.parquet.filter2.compat.FilterCompat
+import org.apache.parquet.filter2.compat.FilterCompat._
+import org.apache.parquet.filter2.predicate.FilterApi._
+import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Statistics}
+import org.apache.parquet.filter2.predicate.UserDefinedPredicate
+import org.apache.parquet.io.api.Binary
import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.sources
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
private[sql] object ParquetFilters {
val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
@@ -41,6 +44,18 @@ private[sql] object ParquetFilters {
}.reduceOption(FilterApi.and).map(FilterCompat.get)
}
+ case class SetInFilter[T <: Comparable[T]](
+ valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable {
+
+ override def keep(value: T): Boolean = {
+ value != null && valueSet.contains(value)
+ }
+
+ override def canDrop(statistics: Statistics[T]): Boolean = false
+
+ override def inverseCanDrop(statistics: Statistics[T]): Boolean = false
+ }
+
private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case BooleanType =>
(n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
@@ -153,6 +168,29 @@ private[sql] object ParquetFilters {
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
+ private val makeInSet: PartialFunction[DataType, (String, Set[Any]) => FilterPredicate] = {
+ case IntegerType =>
+ (n: String, v: Set[Any]) =>
+ FilterApi.userDefined(intColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Integer]]))
+ case LongType =>
+ (n: String, v: Set[Any]) =>
+ FilterApi.userDefined(longColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Long]]))
+ case FloatType =>
+ (n: String, v: Set[Any]) =>
+ FilterApi.userDefined(floatColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Float]]))
+ case DoubleType =>
+ (n: String, v: Set[Any]) =>
+ FilterApi.userDefined(doubleColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Double]]))
+ case StringType =>
+ (n: String, v: Set[Any]) =>
+ FilterApi.userDefined(binaryColumn(n),
+ SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes))))
+ case BinaryType =>
+ (n: String, v: Set[Any]) =>
+ FilterApi.userDefined(binaryColumn(n),
+ SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[Array[Byte]]))))
+ }
+
/**
* Converts data sources filters to Parquet filter predicates.
*/
@@ -284,6 +322,9 @@ private[sql] object ParquetFilters {
case Not(pred) =>
createFilter(pred).map(FilterApi.not)
+ case InSet(NamedExpression(name, dataType), valueSet) =>
+ makeInSet.lift(dataType).map(_(name, valueSet))
+
case _ => None
}
}
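
Editor's note: the new `makeInSet` pushes `InSet` predicates to Parquet through `FilterApi.userDefined`. The sketch below shows that wiring in isolation; since `SetInFilter` is private to `ParquetFilters`, the predicate is re-declared locally under the hypothetical name `KeepInSet`, and the column name and values are made up.

import java.io.Serializable

import org.apache.parquet.filter2.predicate.{FilterApi, Statistics, UserDefinedPredicate}
import org.apache.parquet.filter2.predicate.FilterApi.intColumn

// Local copy of the predicate shape from the hunk, so the sketch is self-contained.
case class KeepInSet[T <: Comparable[T]](valueSet: Set[T])
  extends UserDefinedPredicate[T] with Serializable {
  override def keep(value: T): Boolean = value != null && valueSet.contains(value)
  override def canDrop(statistics: Statistics[T]): Boolean = false
  override def inverseCanDrop(statistics: Statistics[T]): Boolean = false
}

object InSetPushdownSketch {
  def main(args: Array[String]): Unit = {
    val keep = Set[java.lang.Integer](1, 2, 3)
    // Mirrors makeInSet for IntegerType: wrap the value set in a user-defined predicate.
    val predicate = FilterApi.userDefined(intColumn("id"), KeepInSet(keep))
    println(predicate)
  }
}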
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index fcb9513ab66f..704cf56f3826 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -18,20 +18,21 @@
package org.apache.spark.sql.parquet
import java.io.IOException
-import java.util.logging.Level
+import java.util.logging.{Level, Logger => JLogger}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.permission.FsAction
-import org.apache.spark.sql.types.{StructType, DataType}
-import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat}
-import parquet.hadoop.metadata.CompressionCodecName
-import parquet.schema.MessageType
+import org.apache.parquet.hadoop.metadata.CompressionCodecName
+import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat, ParquetRecordReader}
+import org.apache.parquet.schema.MessageType
+import org.apache.parquet.{Log => ParquetLog}
-import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException}
-import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, SQLContext}
/**
* Relation that consists of data stored in a Parquet columnar format.
@@ -94,40 +95,44 @@ private[sql] case class ParquetRelation(
private[sql] object ParquetRelation {
def enableLogForwarding() {
- // Note: the parquet.Log class has a static initializer that
- // sets the java.util.logging Logger for "parquet". This
+ // Note: the org.apache.parquet.Log class has a static initializer that
+ // sets the java.util.logging Logger for "org.apache.parquet". This
// checks first to see if there's any handlers already set
// and if not it creates them. If this method executes prior
// to that class being loaded then:
// 1) there's no handlers installed so there's none to
// remove. But when it IS finally loaded the desired affect
// of removing them is circumvented.
- // 2) The parquet.Log static initializer calls setUseParentHanders(false)
+ // 2) The parquet.Log static initializer calls setUseParentHandlers(false)
// undoing the attempt to override the logging here.
//
// Therefore we need to force the class to be loaded.
// This should really be resolved by Parquet.
- Class.forName(classOf[parquet.Log].getName)
+ Class.forName(classOf[ParquetLog].getName)
// Note: Logger.getLogger("parquet") has a default logger
// that appends to Console which needs to be cleared.
- val parquetLogger = java.util.logging.Logger.getLogger("parquet")
+ val parquetLogger = JLogger.getLogger(classOf[ParquetLog].getPackage.getName)
parquetLogger.getHandlers.foreach(parquetLogger.removeHandler)
- // TODO(witgo): Need to set the log level ?
- // if(parquetLogger.getLevel != null) parquetLogger.setLevel(null)
- if (!parquetLogger.getUseParentHandlers) parquetLogger.setUseParentHandlers(true)
+ parquetLogger.setUseParentHandlers(true)
- // Disables WARN log message in ParquetOutputCommitter.
+ // Disables a WARN log message in ParquetOutputCommitter. We first ensure that
+ // ParquetOutputCommitter is loaded and the static LOG field gets initialized.
// See https://issues.apache.org/jira/browse/SPARK-5968 for details
Class.forName(classOf[ParquetOutputCommitter].getName)
- java.util.logging.Logger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF)
+ JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF)
+
+ // Similar to the above, disables an unnecessary WARN log message in ParquetRecordReader.

+ // See https://issues.apache.org/jira/browse/PARQUET-220 for details
+ Class.forName(classOf[ParquetRecordReader[_]].getName)
+ JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF)
}
// The element type for the RDDs that this relation maps to.
type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow
// The compression type
- type CompressionType = parquet.hadoop.metadata.CompressionCodecName
+ type CompressionType = org.apache.parquet.hadoop.metadata.CompressionCodecName
// The parquet compression short names
val shortParquetCompressionCodecNames = Map(
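
Editor's note: `enableLogForwarding` above works by forcing the classes that install `java.util.logging` handlers to load first, then adjusting their loggers. A standalone sketch of that trick, using only the classes the hunk itself imports:

import java.util.logging.{Level, Logger => JLogger}

import org.apache.parquet.{Log => ParquetLog}
import org.apache.parquet.hadoop.ParquetOutputCommitter

object ParquetLogSilencerSketch {
  def main(args: Array[String]): Unit = {
    // Loading org.apache.parquet.Log runs its static initializer, which installs
    // a console handler on the "org.apache.parquet" logger; remove it and forward
    // records to the parent handlers instead.
    Class.forName(classOf[ParquetLog].getName)
    val parquetLogger = JLogger.getLogger(classOf[ParquetLog].getPackage.getName)
    parquetLogger.getHandlers.foreach(parquetLogger.removeHandler)
    parquetLogger.setUseParentHandlers(true)

    // Same trick for the noisy WARN in ParquetOutputCommitter (SPARK-5968).
    Class.forName(classOf[ParquetOutputCommitter].getName)
    JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF)
  }
}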
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index cb7ae246d0d7..39360e13313a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -33,20 +33,20 @@ import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path}
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat}
-import parquet.hadoop._
-import parquet.hadoop.api.ReadSupport.ReadContext
-import parquet.hadoop.api.{InitContext, ReadSupport}
-import parquet.hadoop.metadata.GlobalMetaData
-import parquet.hadoop.util.ContextUtil
-import parquet.io.ParquetDecodingException
-import parquet.schema.MessageType
+import org.apache.parquet.hadoop._
+import org.apache.parquet.hadoop.api.ReadSupport.ReadContext
+import org.apache.parquet.hadoop.api.{InitContext, ReadSupport}
+import org.apache.parquet.hadoop.metadata.GlobalMetaData
+import org.apache.parquet.hadoop.util.ContextUtil
+import org.apache.parquet.io.ParquetDecodingException
+import org.apache.parquet.schema.MessageType
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLConf
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow, _}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
import org.apache.spark.sql.types.StructType
import org.apache.spark.{Logging, SerializableWritable, TaskContext}
@@ -54,7 +54,7 @@ import org.apache.spark.{Logging, SerializableWritable, TaskContext}
/**
* :: DeveloperApi ::
* Parquet table scan operator. Imports the file that backs the given
- * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``.
+ * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[InternalRow]``.
*/
private[sql] case class ParquetTableScan(
attributes: Seq[Attribute],
@@ -77,8 +77,8 @@ private[sql] case class ParquetTableScan(
}
}.toArray
- protected override def doExecute(): RDD[Row] = {
- import parquet.filter2.compat.FilterCompat.FilterPredicateCompat
+ protected override def doExecute(): RDD[InternalRow] = {
+ import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat
val sc = sqlContext.sparkContext
val job = new Job(sc.hadoopConfiguration)
@@ -117,12 +117,15 @@ private[sql] case class ParquetTableScan(
SQLConf.PARQUET_CACHE_METADATA,
sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true"))
+ // Use task-side metadata in Parquet
+ conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true)
+
val baseRDD =
new org.apache.spark.rdd.NewHadoopRDD(
sc,
classOf[FilteringParquetRowInputFormat],
classOf[Void],
- classOf[Row],
+ classOf[InternalRow],
conf)
if (requestedPartitionOrdinals.nonEmpty) {
@@ -136,7 +139,7 @@ private[sql] case class ParquetTableScan(
baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
val partValue = "([^=]+)=([^=]+)".r
val partValues =
- split.asInstanceOf[parquet.hadoop.ParquetInputSplit]
+ split.asInstanceOf[org.apache.parquet.hadoop.ParquetInputSplit]
.getPath
.toString
.split("/")
@@ -151,9 +154,9 @@ private[sql] case class ParquetTableScan(
.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
if (primitiveRow) {
- new Iterator[Row] {
+ new Iterator[InternalRow] {
def hasNext: Boolean = iter.hasNext
- def next(): Row = {
+ def next(): InternalRow = {
// We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow.
val row = iter.next()._2.asInstanceOf[SpecificMutableRow]
@@ -170,12 +173,12 @@ private[sql] case class ParquetTableScan(
} else {
// Create a mutable row since we need to fill in values from partition columns.
val mutableRow = new GenericMutableRow(outputSize)
- new Iterator[Row] {
+ new Iterator[InternalRow] {
def hasNext: Boolean = iter.hasNext
- def next(): Row = {
+ def next(): InternalRow = {
// We are using CatalystGroupConverter and it returns a GenericRow.
// Since GenericRow is not mutable, we just cast it to a Row.
- val row = iter.next()._2.asInstanceOf[Row]
+ val row = iter.next()._2.asInstanceOf[InternalRow]
var i = 0
while (i < row.size) {
@@ -255,7 +258,7 @@ private[sql] case class InsertIntoParquetTable(
/**
* Inserts all rows into the Parquet file.
*/
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
// TODO: currently we do not check whether the "schema"s are compatible
// That means if one first creates a table and then INSERTs data with
// and incompatible schema the execution will fail. It would be nice
@@ -318,13 +321,13 @@ private[sql] case class InsertIntoParquetTable(
* @param conf A [[org.apache.hadoop.conf.Configuration]].
*/
private def saveAsHadoopFile(
- rdd: RDD[Row],
+ rdd: RDD[InternalRow],
path: String,
conf: Configuration) {
val job = new Job(conf)
val keyType = classOf[Void]
job.setOutputKeyClass(keyType)
- job.setOutputValueClass(classOf[Row])
+ job.setOutputValueClass(classOf[InternalRow])
NewFileOutputFormat.setOutputPath(job, new Path(path))
val wrappedConf = new SerializableWritable(job.getConfiguration)
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
@@ -339,7 +342,7 @@ private[sql] case class InsertIntoParquetTable(
.findMaxTaskId(NewFileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1
}
- def writeShard(context: TaskContext, iter: Iterator[Row]): Int = {
+ def writeShard(context: TaskContext, iter: Iterator[InternalRow]): Int = {
/* "reduce task" */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
context.attemptNumber)
@@ -378,7 +381,7 @@ private[sql] case class InsertIntoParquetTable(
* to imported ones.
*/
private[parquet] class AppendingParquetOutputFormat(offset: Int)
- extends parquet.hadoop.ParquetOutputFormat[Row] {
+ extends org.apache.parquet.hadoop.ParquetOutputFormat[InternalRow] {
// override to accept existing directories as valid output directory
override def checkOutputSpecs(job: JobContext): Unit = {}
var committer: OutputCommitter = null
@@ -431,210 +434,26 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int)
* RecordFilter we want to use.
*/
private[parquet] class FilteringParquetRowInputFormat
- extends parquet.hadoop.ParquetInputFormat[Row] with Logging {
+ extends org.apache.parquet.hadoop.ParquetInputFormat[InternalRow] with Logging {
private var fileStatuses = Map.empty[Path, FileStatus]
override def createRecordReader(
inputSplit: InputSplit,
- taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = {
+ taskAttemptContext: TaskAttemptContext): RecordReader[Void, InternalRow] = {
- import parquet.filter2.compat.FilterCompat.NoOpFilter
+ import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter
- val readSupport: ReadSupport[Row] = new RowReadSupport()
+ val readSupport: ReadSupport[InternalRow] = new RowReadSupport()
val filter = ParquetInputFormat.getFilter(ContextUtil.getConfiguration(taskAttemptContext))
if (!filter.isInstanceOf[NoOpFilter]) {
- new ParquetRecordReader[Row](
+ new ParquetRecordReader[InternalRow](
readSupport,
filter)
} else {
- new ParquetRecordReader[Row](readSupport)
- }
- }
-
- // This is only a temporary solution sicne we need to use fileStatuses in
- // both getClientSideSplits and getTaskSideSplits. It can be removed once we get rid of these
- // two methods.
- override def getSplits(jobContext: JobContext): JList[InputSplit] = {
- // First set fileStatuses.
- val statuses = listStatus(jobContext)
- fileStatuses = statuses.map(file => file.getPath -> file).toMap
-
- super.getSplits(jobContext)
- }
-
- // TODO Remove this method and related code once PARQUET-16 is fixed
- // This method together with the `getFooters` method and the `fileStatuses` field are just used
- // to mimic this PR: https://github.com/apache/incubator-parquet-mr/pull/17
- override def getSplits(
- configuration: Configuration,
- footers: JList[Footer]): JList[ParquetInputSplit] = {
-
- // Use task side strategy by default
- val taskSideMetaData = configuration.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true)
- val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue)
- val minSplitSize: JLong =
- Math.max(getFormatMinSplitSize, configuration.getLong("mapred.min.split.size", 0L))
- if (maxSplitSize < 0 || minSplitSize < 0) {
- throw new ParquetDecodingException(
- s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" +
- s" minSplitSize = $minSplitSize")
- }
-
- // Uses strict type checking by default
- val getGlobalMetaData =
- classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]])
- getGlobalMetaData.setAccessible(true)
- var globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData]
-
- if (globalMetaData == null) {
- val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
- return splits
- }
-
- val metadata = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA)
- val mergedMetadata = globalMetaData
- .getKeyValueMetaData
- .updated(RowReadSupport.SPARK_METADATA_KEY, setAsJavaSet(Set(metadata)))
-
- globalMetaData = new GlobalMetaData(globalMetaData.getSchema,
- mergedMetadata, globalMetaData.getCreatedBy)
-
- val readContext = getReadSupport(configuration).init(
- new InitContext(configuration,
- globalMetaData.getKeyValueMetaData,
- globalMetaData.getSchema))
-
- if (taskSideMetaData){
- logInfo("Using Task Side Metadata Split Strategy")
- getTaskSideSplits(configuration,
- footers,
- maxSplitSize,
- minSplitSize,
- readContext)
- } else {
- logInfo("Using Client Side Metadata Split Strategy")
- getClientSideSplits(configuration,
- footers,
- maxSplitSize,
- minSplitSize,
- readContext)
+ new ParquetRecordReader[InternalRow](readSupport)
}
-
- }
-
- def getClientSideSplits(
- configuration: Configuration,
- footers: JList[Footer],
- maxSplitSize: JLong,
- minSplitSize: JLong,
- readContext: ReadContext): JList[ParquetInputSplit] = {
-
- import parquet.filter2.compat.FilterCompat.Filter
- import parquet.filter2.compat.RowGroupFilter
-
- import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache
-
- val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true)
-
- val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
- val filter: Filter = ParquetInputFormat.getFilter(configuration)
- var rowGroupsDropped: Long = 0
- var totalRowGroups: Long = 0
-
- // Ugly hack, stuck with it until PR:
- // https://github.com/apache/incubator-parquet-mr/pull/17
- // is resolved
- val generateSplits =
- Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy")
- .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse(
- sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits"))
- generateSplits.setAccessible(true)
-
- for (footer <- footers) {
- val fs = footer.getFile.getFileSystem(configuration)
- val file = footer.getFile
- val status = fileStatuses.getOrElse(file, fs.getFileStatus(file))
- val parquetMetaData = footer.getParquetMetadata
- val blocks = parquetMetaData.getBlocks
- totalRowGroups = totalRowGroups + blocks.size
- val filteredBlocks = RowGroupFilter.filterRowGroups(
- filter,
- blocks,
- parquetMetaData.getFileMetaData.getSchema)
- rowGroupsDropped = rowGroupsDropped + (blocks.size - filteredBlocks.size)
-
- if (!filteredBlocks.isEmpty){
- var blockLocations: Array[BlockLocation] = null
- if (!cacheMetadata) {
- blockLocations = fs.getFileBlockLocations(status, 0, status.getLen)
- } else {
- blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] {
- def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen)
- })
- }
- splits.addAll(
- generateSplits.invoke(
- null,
- filteredBlocks,
- blockLocations,
- status,
- readContext.getRequestedSchema.toString,
- readContext.getReadSupportMetadata,
- minSplitSize,
- maxSplitSize).asInstanceOf[JList[ParquetInputSplit]])
- }
- }
-
- if (rowGroupsDropped > 0 && totalRowGroups > 0){
- val percentDropped = ((rowGroupsDropped/totalRowGroups.toDouble) * 100).toInt
- logInfo(s"Dropping $rowGroupsDropped row groups that do not pass filter predicate "
- + s"($percentDropped %) !")
- }
- else {
- logInfo("There were no row groups that could be dropped due to filter predicates")
- }
- splits
-
- }
-
- def getTaskSideSplits(
- configuration: Configuration,
- footers: JList[Footer],
- maxSplitSize: JLong,
- minSplitSize: JLong,
- readContext: ReadContext): JList[ParquetInputSplit] = {
-
- val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
-
- // Ugly hack, stuck with it until PR:
- // https://github.com/apache/incubator-parquet-mr/pull/17
- // is resolved
- val generateSplits =
- Class.forName("parquet.hadoop.TaskSideMetadataSplitStrategy")
- .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse(
- sys.error(
- s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits"))
- generateSplits.setAccessible(true)
-
- for (footer <- footers) {
- val file = footer.getFile
- val fs = file.getFileSystem(configuration)
- val status = fileStatuses.getOrElse(file, fs.getFileStatus(file))
- val blockLocations = fs.getFileBlockLocations(status, 0, status.getLen)
- splits.addAll(
- generateSplits.invoke(
- null,
- blockLocations,
- status,
- readContext.getRequestedSchema.toString,
- readContext.getReadSupportMetadata,
- minSplitSize,
- maxSplitSize).asInstanceOf[JList[ParquetInputSplit]])
- }
-
- splits
}
}
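
Editor's note: the large block removed above reflectively re-implemented Parquet's client-side split planning; with `ParquetInputFormat.TASK_SIDE_METADATA` hard-coded earlier in this file, row groups are filtered inside each task instead. A minimal sketch of setting that flag on a Hadoop `Configuration`, using only the constant the hunk references:

import org.apache.hadoop.conf.Configuration

import org.apache.parquet.hadoop.ParquetInputFormat

object TaskSideMetadataSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    // Let Parquet drop row groups on the task side rather than during split planning.
    conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true)
    println(conf.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, false)) // prints: true
  }
}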
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 70a220cc43ab..a8775a2a8fd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -20,16 +20,18 @@ package org.apache.spark.sql.parquet
import java.util.{HashMap => JHashMap}
import org.apache.hadoop.conf.Configuration
-import parquet.column.ParquetProperties
-import parquet.hadoop.ParquetOutputFormat
-import parquet.hadoop.api.ReadSupport.ReadContext
-import parquet.hadoop.api.{ReadSupport, WriteSupport}
-import parquet.io.api._
-import parquet.schema.MessageType
+import org.apache.parquet.column.ParquetProperties
+import org.apache.parquet.hadoop.ParquetOutputFormat
+import org.apache.parquet.hadoop.api.ReadSupport.ReadContext
+import org.apache.parquet.hadoop.api.{ReadSupport, WriteSupport}
+import org.apache.parquet.io.api._
+import org.apache.parquet.schema.MessageType
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Row}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow}
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/**
* A `parquet.io.api.RecordMaterializer` for Rows.
@@ -37,12 +39,12 @@ import org.apache.spark.sql.types._
*@param root The root group converter for the record.
*/
private[parquet] class RowRecordMaterializer(root: CatalystConverter)
- extends RecordMaterializer[Row] {
+ extends RecordMaterializer[InternalRow] {
def this(parquetSchema: MessageType, attributes: Seq[Attribute]) =
this(CatalystConverter.createRootConverter(parquetSchema, attributes))
- override def getCurrentRecord: Row = root.getCurrentRecord
+ override def getCurrentRecord: InternalRow = root.getCurrentRecord
override def getRootConverter: GroupConverter = root.asInstanceOf[GroupConverter]
}
@@ -50,13 +52,13 @@ private[parquet] class RowRecordMaterializer(root: CatalystConverter)
/**
* A `parquet.hadoop.api.ReadSupport` for Row objects.
*/
-private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging {
+private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logging {
override def prepareForRead(
conf: Configuration,
stringMap: java.util.Map[String, String],
fileSchema: MessageType,
- readContext: ReadContext): RecordMaterializer[Row] = {
+ readContext: ReadContext): RecordMaterializer[InternalRow] = {
log.debug(s"preparing for read with Parquet file schema $fileSchema")
// Note: this very much imitates AvroParquet
val parquetSchema = readContext.getRequestedSchema
@@ -131,7 +133,7 @@ private[parquet] object RowReadSupport {
/**
* A `parquet.hadoop.api.WriteSupport` for Row objects.
*/
-private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
+private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging {
private[parquet] var writer: RecordConsumer = null
private[parquet] var attributes: Array[Attribute] = null
@@ -155,7 +157,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
log.debug(s"preparing for write with schema $attributes")
}
- override def write(record: Row): Unit = {
+ override def write(record: InternalRow): Unit = {
val attributesSize = attributes.size
if (attributesSize > record.size) {
throw new IndexOutOfBoundsException(
@@ -204,7 +206,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
case IntegerType => writer.addInteger(value.asInstanceOf[Int])
case ShortType => writer.addInteger(value.asInstanceOf[Short])
case LongType => writer.addLong(value.asInstanceOf[Long])
- case TimestampType => writeTimestamp(value.asInstanceOf[java.sql.Timestamp])
+ case TimestampType => writeTimestamp(value.asInstanceOf[Long])
case ByteType => writer.addInteger(value.asInstanceOf[Byte])
case DoubleType => writer.addDouble(value.asInstanceOf[Double])
case FloatType => writer.addFloat(value.asInstanceOf[Float])
@@ -311,15 +313,16 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
}
- private[parquet] def writeTimestamp(ts: java.sql.Timestamp): Unit = {
- val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp(ts)
+ private[parquet] def writeTimestamp(ts: Long): Unit = {
+ val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp(
+ DateUtils.toJavaTimestamp(ts))
writer.addBinary(binaryNanoTime)
}
}
// Optimized for non-nested rows
private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
- override def write(record: Row): Unit = {
+ override def write(record: InternalRow): Unit = {
val attributesSize = attributes.size
if (attributesSize > record.size) {
throw new IndexOutOfBoundsException(
@@ -342,7 +345,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
private def consumeType(
ctype: DataType,
- record: Row,
+ record: InternalRow,
index: Int): Unit = {
ctype match {
case StringType => writer.addBinary(
@@ -357,7 +360,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
case FloatType => writer.addFloat(record.getFloat(index))
case BooleanType => writer.addBoolean(record.getBoolean(index))
case DateType => writer.addInteger(record.getInt(index))
- case TimestampType => writeTimestamp(record(index).asInstanceOf[java.sql.Timestamp])
+ case TimestampType => writeTimestamp(record(index).asInstanceOf[Long])
case d: DecimalType =>
if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
sys.error(s"Unsupported datatype $d, cannot write to consumer")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index 6698b19c7477..ba2a35b74ef8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -19,26 +19,25 @@ package org.apache.spark.sql.parquet
import java.io.IOException
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConversions._
import scala.util.Try
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
-import parquet.format.converter.ParquetMetadataConverter
-import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
-import parquet.hadoop.util.ContextUtil
-import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter}
-import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName}
-import parquet.schema.Type.Repetition
-import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes}
+import org.apache.parquet.format.converter.ParquetMetadataConverter
+import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
+import org.apache.parquet.hadoop.util.ContextUtil
+import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter}
+import org.apache.parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName}
+import org.apache.parquet.schema.Type.Repetition
+import org.apache.parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes}
+import org.apache.spark.Logging
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types._
-import org.apache.spark.{Logging, SparkException}
-// Implicits
-import scala.collection.JavaConversions._
/** A class representing Parquet info fields we care about, for passing back to Parquet */
private[parquet] case class ParquetTypeInfo(
@@ -73,13 +72,12 @@ private[parquet] object ParquetTypesConverter extends Logging {
case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType
case ParquetPrimitiveTypeName.INT96 =>
// TODO: add BigInteger type? TODO(andre) use DecimalType instead????
- sys.error("Potential loss of precision: cannot convert INT96")
+ throw new AnalysisException("Potential loss of precision: cannot convert INT96")
case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY
if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) =>
// TODO: for now, our reader only supports decimals that fit in a Long
DecimalType(decimalInfo.getPrecision, decimalInfo.getScale)
- case _ => sys.error(
- s"Unsupported parquet datatype $parquetType")
+ case _ => throw new AnalysisException(s"Unsupported parquet datatype $parquetType")
}
}
@@ -371,7 +369,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
parquetKeyType,
parquetValueType)
}
- case _ => sys.error(s"Unsupported datatype $ctype")
+ case _ => throw new AnalysisException(s"Unsupported datatype $ctype")
}
}
}
@@ -403,7 +401,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
def convertFromString(string: String): Seq[Attribute] = {
Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match {
case s: StructType => s.toAttributes
- case other => sys.error(s"Can convert $string to row")
+ case other => throw new AnalysisException(s"Can convert $string to row")
}
}
@@ -411,8 +409,8 @@ private[parquet] object ParquetTypesConverter extends Logging {
// ,;{}()\n\t= and space character are special characters in Parquet schema
schema.map(_.name).foreach { name =>
if (name.matches(".*[ ,;{}()\n\t=].*")) {
- sys.error(
- s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\n\t=".
+ throw new AnalysisException(
+ s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=".
|Please use alias to rename it.
""".stripMargin.split("\n").mkString(" "))
}
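
Editor's note: the validation above rejects column names containing Parquet's special schema characters (" ,;{}()\n\t="). A standalone sketch of that check, using `IllegalArgumentException` in place of Spark's `AnalysisException`:

object ParquetNameCheckSketch {
  def checkName(name: String): Unit = {
    // Same character class as the hunk: space, comma, semicolon, braces, parens, newline, tab, '='.
    if (name.matches(".*[ ,;{}()\n\t=].*")) {
      throw new IllegalArgumentException(
        s"""Attribute name "$name" contains invalid character(s); use an alias to rename it.""")
    }
  }

  def main(args: Array[String]): Unit = {
    checkName("valid_name") // passes silently
    try checkName("bad name") catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }
  }
}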
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 8b3e1b2b59bf..bba6f1ec96aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -29,20 +29,21 @@ import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
-import parquet.filter2.predicate.FilterApi
-import parquet.hadoop._
-import parquet.hadoop.metadata.CompressionCodecName
-import parquet.hadoop.util.ContextUtil
+import org.apache.parquet.filter2.predicate.FilterApi
+import org.apache.parquet.hadoop._
+import org.apache.parquet.hadoop.metadata.CompressionCodecName
+import org.apache.parquet.hadoop.util.ContextUtil
-import org.apache.spark.{Partition => SparkPartition, SerializableWritable, Logging, SparkException}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.rdd.RDD._
import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.RDD._
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.sql.{Row, SQLConf, SQLContext}
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SerializableWritable, SparkException, Partition => SparkPartition}
private[sql] class DefaultSource extends HadoopFsRelationProvider {
override def createRelation(
@@ -59,7 +60,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider {
private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext)
extends OutputWriter {
- private val recordWriter: RecordWriter[Void, Row] = {
+ private val recordWriter: RecordWriter[Void, InternalRow] = {
val conf = context.getConfiguration
val outputFormat = {
// When appending new Parquet files to an existing Parquet file directory, to avoid
@@ -83,7 +84,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext
case partFilePattern(id) => id.toInt
case name if name.startsWith("_") => 0
case name if name.startsWith(".") => 0
- case name => sys.error(
+ case name => throw new AnalysisException(
s"Trying to write Parquet files to directory $outputPath, " +
s"but found items with illegal name '$name'.")
}.reduceOption(_ max _).getOrElse(0)
@@ -92,7 +93,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext
}
}
- new ParquetOutputFormat[Row]() {
+ new ParquetOutputFormat[InternalRow]() {
// Here we override `getDefaultWorkFile` for two reasons:
//
// 1. To allow appending. We need to generate output file name based on the max available
@@ -111,7 +112,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext
outputFormat.getRecordWriter(context)
}
- override def write(row: Row): Unit = recordWriter.write(null, row)
+ override def write(row: Row): Unit = recordWriter.write(null, row.asInstanceOf[InternalRow])
override def close(): Unit = recordWriter.close(context)
}
@@ -155,7 +156,7 @@ private[sql] class ParquetRelation2(
meta
}
- override def equals(other: scala.Any): Boolean = other match {
+ override def equals(other: Any): Boolean = other match {
case that: ParquetRelation2 =>
val schemaEquality = if (shouldMergeSchemas) {
this.shouldMergeSchemas == that.shouldMergeSchemas
@@ -190,7 +191,7 @@ private[sql] class ParquetRelation2(
}
}
- override def dataSchema: StructType = metadataCache.dataSchema
+ override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema)
override private[sql] def refresh(): Unit = {
super.refresh()
@@ -211,6 +212,13 @@ private[sql] class ParquetRelation2(
classOf[ParquetOutputCommitter],
classOf[ParquetOutputCommitter])
+ if (conf.get("spark.sql.parquet.output.committer.class") == null) {
+ logInfo("Using default output committer for Parquet: " +
+ classOf[ParquetOutputCommitter].getCanonicalName)
+ } else {
+ logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName)
+ }
+
conf.setClass(
SQLConf.OUTPUT_COMMITTER_CLASS,
committerClass,
@@ -278,7 +286,7 @@ private[sql] class ParquetRelation2(
initLocalJobFuncOpt = Some(initLocalJobFuncOpt),
inputFormatClass = classOf[FilteringParquetRowInputFormat],
keyClass = classOf[Void],
- valueClass = classOf[Row]) {
+ valueClass = classOf[InternalRow]) {
val cacheMetadata = useMetadataCache
@@ -323,7 +331,7 @@ private[sql] class ParquetRelation2(
new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
}
}
- }.values
+ }.values.map(_.asInstanceOf[Row])
}
}
@@ -380,11 +388,12 @@ private[sql] class ParquetRelation2(
// time-consuming.
if (dataSchema == null) {
dataSchema = {
- val dataSchema0 =
- maybeDataSchema
- .orElse(readSchema())
- .orElse(maybeMetastoreSchema)
- .getOrElse(sys.error("Failed to get the schema."))
+ val dataSchema0 = maybeDataSchema
+ .orElse(readSchema())
+ .orElse(maybeMetastoreSchema)
+ .getOrElse(throw new AnalysisException(
+ s"Failed to discover schema of Parquet file(s) in the following location(s):\n" +
+ paths.mkString("\n\t")))
// If this Parquet relation is converted from a Hive Metastore table, we must reconcile the
// case insensitivity issue and possible schema mismatch (probably caused by schema
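As a hedged, standalone sketch of the part-file numbering logic in the ParquetOutputWriter hunk above (not the Spark code itself; the regex and helper names are illustrative): when appending, existing file names are scanned, the maximum part ID wins, and Parquet metadata or hidden files count as zero.

import scala.util.matching.Regex

object PartFileIdSketch {
  // Illustrative pattern for part files such as "part-r-00001.gz.parquet".
  private val PartFilePattern: Regex = """part-r-(\d+).*""".r

  // Next part-file ID for a directory that may already contain part files,
  // "_" metadata files, or "." hidden files; unexpected names are rejected,
  // mirroring the AnalysisException thrown above.
  def nextPartId(existingNames: Seq[String]): Int = {
    val maxExistingId = existingNames.map {
      case PartFilePattern(id) => id.toInt
      case name if name.startsWith("_") => 0
      case name if name.startsWith(".") => 0
      case name => throw new IllegalArgumentException(s"Found item with illegal name '$name'")
    }.reduceOption(_ max _).getOrElse(0)
    maxExistingId + 1
  }
}

// PartFileIdSketch.nextPartId(Seq("part-r-00001.gz.parquet", "_metadata")) == 2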
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala
index 70bcca7526aa..4d5ed211ad0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.parquet.timestamp
import java.nio.{ByteBuffer, ByteOrder}
-import parquet.Preconditions
-import parquet.io.api.{Binary, RecordConsumer}
+import org.apache.parquet.Preconditions
+import org.apache.parquet.io.api.{Binary, RecordConsumer}
private[parquet] class NanoTime extends Serializable {
private var julianDay = 0
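For context on the class touched above: NanoTime pairs a Julian day with nanoseconds within that day, serialized as a 12-byte Parquet INT96 binary. A hedged sketch of that packing, assuming the conventional little-endian layout (8-byte nanos followed by a 4-byte Julian day); this is not the Spark implementation itself.

import java.nio.{ByteBuffer, ByteOrder}

object NanoTimeSketch {
  def pack(julianDay: Int, timeOfDayNanos: Long): Array[Byte] = {
    val buf = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN)
    buf.putLong(timeOfDayNanos) // first 8 bytes: nanoseconds within the day
    buf.putInt(julianDay)       // last 4 bytes: Julian day number
    buf.array()
  }

  def unpack(bytes: Array[Byte]): (Int, Long) = {
    val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    val timeOfDayNanos = buf.getLong
    val julianDay = buf.getInt
    (julianDay, timeOfDayNanos)
  }
}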
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index c6a4dabbab05..4cf67439b9b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -17,47 +17,48 @@
package org.apache.spark.sql.sources
-import org.apache.spark.{Logging, SerializableWritable, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StringType, StructType, UTF8String}
+import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{SaveMode, Strategy, execution, sources}
import org.apache.spark.util.Utils
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.{Logging, SerializableWritable, TaskContext}
/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
private[sql] object DataSourceStrategy extends Strategy with Logging {
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
- case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) =>
+ case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan)) =>
pruneFilterProjectRaw(
l,
- projectList,
+ projects,
filters,
- (a, f) => t.buildScan(a, f)) :: Nil
+ (a, f) => toCatalystRDD(l, a, t.buildScan(a, f))) :: Nil
- case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) =>
+ case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan)) =>
pruneFilterProject(
l,
- projectList,
+ projects,
filters,
- (a, f) => t.buildScan(a, f)) :: Nil
+ (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil
- case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) =>
+ case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan)) =>
pruneFilterProject(
l,
- projectList,
+ projects,
filters,
- (a, _) => t.buildScan(a)) :: Nil
+ (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil
// Scanning partitioned HadoopFsRelation
- case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: HadoopFsRelation))
+ case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation))
if t.partitionSpec.partitionColumns.nonEmpty =>
val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray
@@ -79,13 +80,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
buildPartitionedTableScan(
l,
- projectList,
+ projects,
pushedFilters,
t.partitionSpec.partitionColumns,
selectedPartitions) :: Nil
// Scanning non-partitioned HadoopFsRelation
- case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: HadoopFsRelation)) =>
+ case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) =>
// See buildPartitionedTableScan for the reason that we need to create a shared
// broadcast HadoopConf.
val sharedHadoopConf = SparkHadoopUtil.get.conf
@@ -93,12 +94,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
t.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf))
pruneFilterProject(
l,
- projectList,
+ projects,
filters,
- (a, f) => t.buildScan(a, f, t.paths, confBroadcast)) :: Nil
+ (a, f) =>
+ toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f, t.paths, confBroadcast))) :: Nil
case l @ LogicalRelation(t: TableScan) =>
- createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil
+ execution.PhysicalRDD(l.output, toCatalystRDD(l, t.buildScan())) :: Nil
case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty =>
@@ -118,7 +120,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
filters: Seq[Expression],
partitionColumns: StructType,
partitions: Array[Partition]) = {
- val output = projections.map(_.toAttribute)
val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation]
// Because we are creating one RDD per partition, we need to have a shared HadoopConf.
@@ -137,23 +138,23 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
logicalRelation,
projections,
filters,
- (requiredColumns, filters) => {
+ (columns: Seq[Attribute], filters) => {
val partitionColNames = partitionColumns.fieldNames
// Don't scan any partition columns to save I/O. Here we are being optimistic and
// assuming that the partition column data stored in data files is always consistent with the
// partition values encoded in the partition directory paths.
- val nonPartitionColumns = requiredColumns.filterNot(partitionColNames.contains)
+ val needed = columns.filterNot(a => partitionColNames.contains(a.name))
val dataRows =
- relation.buildScan(nonPartitionColumns, filters, Array(dir), confBroadcast)
+ relation.buildScan(needed.map(_.name).toArray, filters, Array(dir), confBroadcast)
// Merges data values with partition values.
mergeWithPartitionValues(
relation.schema,
- requiredColumns,
+ columns.map(_.name).toArray,
partitionColNames,
partitionValues,
- dataRows)
+ toCatalystRDD(logicalRelation, needed, dataRows))
})
scan.execute()
@@ -166,15 +167,15 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows)
}
- createPhysicalRDD(logicalRelation.relation, output, unionedRows)
+ execution.PhysicalRDD(projections.map(_.toAttribute), unionedRows)
}
private def mergeWithPartitionValues(
schema: StructType,
requiredColumns: Array[String],
partitionColumns: Array[String],
- partitionValues: Row,
- dataRows: RDD[Row]): RDD[Row] = {
+ partitionValues: InternalRow,
+ dataRows: RDD[InternalRow]): RDD[InternalRow] = {
val nonPartitionColumns = requiredColumns.filterNot(partitionColumns.contains)
// If output columns contain any partition column(s), we need to merge scanned data
@@ -185,13 +186,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
val i = partitionColumns.indexOf(name)
if (i != -1) {
// If yes, gets column value from partition values.
- (mutableRow: MutableRow, dataRow: expressions.Row, ordinal: Int) => {
+ (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
mutableRow(ordinal) = partitionValues(i)
}
} else {
// Otherwise, inherits the value from scanned data.
val i = nonPartitionColumns.indexOf(name)
- (mutableRow: MutableRow, dataRow: expressions.Row, ordinal: Int) => {
+ (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
mutableRow(ordinal) = dataRow(i)
}
}
@@ -200,7 +201,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// Since we know for sure that this closure is serializable, we can avoid the overhead
// of cleaning a closure for each RDD by creating our own MapPartitionsRDD. Functionally
// this is equivalent to calling `dataRows.mapPartitions(mapPartitionsFunc)` (SPARK-7718).
- val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[Row]) => {
+ val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => {
val dataTypes = requiredColumns.map(schema(_).dataType)
val mutableRow = new SpecificMutableRow(dataTypes)
iterator.map { dataRow =>
@@ -209,7 +210,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
mergers(i)(mutableRow, dataRow, i)
i += 1
}
- mutableRow.asInstanceOf[expressions.Row]
+ mutableRow.asInstanceOf[InternalRow]
}
}
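A hedged, simplified sketch of the per-column merger closures built in mergeWithPartitionValues above, using plain arrays in place of MutableRow/InternalRow (names are illustrative): each output column is bound once to either the directory-encoded partition values or the scanned data row, so the per-row loop does no name lookups.

object MergerSketch {
  // One closure per required column, resolved once up front.
  def buildMergers(
      requiredColumns: Array[String],
      partitionColumns: Array[String],
      partitionValues: Array[Any]): Array[(Array[Any], Array[Any], Int) => Unit] = {
    val nonPartitionColumns = requiredColumns.filterNot(partitionColumns.contains)
    requiredColumns.map { name =>
      val i = partitionColumns.indexOf(name)
      if (i != -1) {
        // Partition column: value comes from the partition directory path.
        (out: Array[Any], dataRow: Array[Any], ordinal: Int) => out(ordinal) = partitionValues(i)
      } else {
        // Data column: value comes from the scanned data row.
        val j = nonPartitionColumns.indexOf(name)
        (out: Array[Any], dataRow: Array[Any], ordinal: Int) => out(ordinal) = dataRow(j)
      }
    }
  }

  // Per-row merge; the output buffer plays the role of the reused mutableRow above.
  def merge(mergers: Array[(Array[Any], Array[Any], Int) => Unit], dataRow: Array[Any]): Array[Any] = {
    val out = new Array[Any](mergers.length)
    var i = 0
    while (i < mergers.length) {
      mergers(i)(out, dataRow, i)
      i += 1
    }
    out
  }
}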
@@ -255,26 +256,26 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// Based on Public API.
protected def pruneFilterProject(
relation: LogicalRelation,
- projectList: Seq[NamedExpression],
+ projects: Seq[NamedExpression],
filterPredicates: Seq[Expression],
- scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = {
+ scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow]) = {
pruneFilterProjectRaw(
relation,
- projectList,
+ projects,
filterPredicates,
(requestedColumns, pushedFilters) => {
- scanBuilder(requestedColumns.map(_.name).toArray, selectFilters(pushedFilters).toArray)
+ scanBuilder(requestedColumns, selectFilters(pushedFilters).toArray)
})
}
// Based on Catalyst expressions.
protected def pruneFilterProjectRaw(
relation: LogicalRelation,
- projectList: Seq[NamedExpression],
+ projects: Seq[NamedExpression],
filterPredicates: Seq[Expression],
- scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[Row]) = {
+ scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[InternalRow]) = {
- val projectSet = AttributeSet(projectList.flatMap(_.references))
+ val projectSet = AttributeSet(projects.flatMap(_.references))
val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
val filterCondition = filterPredicates.reduceLeftOption(expressions.And)
@@ -282,38 +283,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes.
}}
- if (projectList.map(_.toAttribute) == projectList &&
- projectSet.size == projectList.size &&
+ if (projects.map(_.toAttribute) == projects &&
+ projectSet.size == projects.size &&
filterSet.subsetOf(projectSet)) {
// When it is possible to just use column pruning to get the right projection and
// when the columns of this projection are enough to evaluate all filter conditions,
// just do a scan followed by a filter, with no extra project.
val requestedColumns =
- projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above.
+ projects.asInstanceOf[Seq[Attribute]] // Safe due to if above.
.map(relation.attributeMap) // Match original case of attributes.
- val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute),
- scanBuilder(requestedColumns, pushedFilters))
+ val scan = execution.PhysicalRDD(projects.map(_.toAttribute),
+ scanBuilder(requestedColumns, pushedFilters))
filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)
} else {
val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq
- val scan = createPhysicalRDD(relation.relation, requestedColumns,
+ val scan = execution.PhysicalRDD(requestedColumns,
scanBuilder(requestedColumns, pushedFilters))
- execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan))
+ execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan))
}
}
- private[this] def createPhysicalRDD(
- relation: BaseRelation,
+ /**
+ * Converts an RDD of Row into an RDD of InternalRow with objects in Catalyst types.
+ */
+ private[this] def toCatalystRDD(
+ relation: LogicalRelation,
output: Seq[Attribute],
- rdd: RDD[Row]): SparkPlan = {
- val converted = if (relation.needConversion) {
- execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
+ rdd: RDD[Row]): RDD[InternalRow] = {
+ if (relation.relation.needConversion) {
+ execution.RDDConversions.rowToRowRdd(rdd.asInstanceOf[RDD[Row]], output.map(_.dataType))
} else {
- rdd
+ rdd.map(_.asInstanceOf[InternalRow])
}
- execution.PhysicalRDD(output, converted)
+ }
+
+ /**
+ * Converts an RDD of Row into an RDD of InternalRow with objects in Catalyst types.
+ */
+ private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = {
+ toCatalystRDD(relation, relation.output, rdd)
}
/**
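The pruneFilterProjectRaw hunk above keeps the existing rule for when an explicit Project node can be skipped. A hedged sketch of that decision with simplified types (column names instead of Catalyst attributes; not the actual planner code):

object ProjectPruningSketch {
  // True when a scan plus an optional Filter is enough, i.e. no Project node is needed.
  def canSkipProject(
      projectedColumns: Seq[String],    // stands in for projects
      projectIsColumnsOnly: Boolean,    // projects.map(_.toAttribute) == projects
      filterColumns: Set[String]): Boolean = {
    val projectSet = projectedColumns.toSet
    projectIsColumnsOnly &&
      projectSet.size == projectedColumns.size &&  // no duplicated columns in the projection
      filterColumns.subsetOf(projectSet)           // filters only touch projected columns
  }
}

// ProjectPruningSketch.canSkipProject(Seq("a", "b"), projectIsColumnsOnly = true, Set("a"))      == true
// ProjectPruningSketch.canSkipProject(Seq("a", "b"), projectIsColumnsOnly = true, Set("a", "c")) == false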
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala
index dafdf0f8b456..c6f535dde767 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources
-import java.lang.{Double => JDouble, Float => JFloat, Long => JLong}
+import java.lang.{Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import scala.collection.mutable.ArrayBuffer
@@ -25,12 +25,11 @@ import scala.util.Try
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.Shell
-
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types._
-private[sql] case class Partition(values: Row, path: String)
+private[sql] case class Partition(values: InternalRow, path: String)
private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition])
@@ -72,10 +71,11 @@ private[sql] object PartitioningUtils {
*/
private[sql] def parsePartitions(
paths: Seq[Path],
- defaultPartitionName: String): PartitionSpec = {
+ defaultPartitionName: String,
+ typeInference: Boolean): PartitionSpec = {
// First, we need to parse every partition's path and see if we can find partition values.
val pathsWithPartitionValues = paths.flatMap { path =>
- parsePartition(path, defaultPartitionName).map(path -> _)
+ parsePartition(path, defaultPartitionName, typeInference).map(path -> _)
}
if (pathsWithPartitionValues.isEmpty) {
@@ -99,7 +99,7 @@ private[sql] object PartitioningUtils {
// Finally, we create `Partition`s based on paths and resolved partition values.
val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map {
case (PartitionValues(_, literals), (path, _)) =>
- Partition(Row.fromSeq(literals.map(_.value)), path.toString)
+ Partition(InternalRow.fromSeq(literals.map(_.value)), path.toString)
}
PartitionSpec(StructType(fields), partitions)
@@ -124,7 +124,8 @@ private[sql] object PartitioningUtils {
*/
private[sql] def parsePartition(
path: Path,
- defaultPartitionName: String): Option[PartitionValues] = {
+ defaultPartitionName: String,
+ typeInference: Boolean): Option[PartitionValues] = {
val columns = ArrayBuffer.empty[(String, Literal)]
// Old Hadoop versions don't have `Path.isRoot`
var finished = path.getParent == null
@@ -137,7 +138,7 @@ private[sql] object PartitioningUtils {
return None
}
- val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName)
+ val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName, typeInference)
maybeColumn.foreach(columns += _)
chopped = chopped.getParent
finished = maybeColumn.isEmpty || chopped.getParent == null
@@ -153,7 +154,8 @@ private[sql] object PartitioningUtils {
private def parsePartitionColumn(
columnSpec: String,
- defaultPartitionName: String): Option[(String, Literal)] = {
+ defaultPartitionName: String,
+ typeInference: Boolean): Option[(String, Literal)] = {
val equalSignIndex = columnSpec.indexOf('=')
if (equalSignIndex == -1) {
None
@@ -164,7 +166,7 @@ private[sql] object PartitioningUtils {
val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")
- val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName)
+ val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName, typeInference)
Some(columnName -> literal)
}
}
@@ -175,7 +177,7 @@ private[sql] object PartitioningUtils {
* {{{
* NullType ->
* IntegerType -> LongType ->
- * FloatType -> DoubleType -> DecimalType.Unlimited ->
+ * DoubleType -> DecimalType.Unlimited ->
* StringType
* }}}
*/
@@ -187,7 +189,7 @@ private[sql] object PartitioningUtils {
Seq.empty
} else {
assert(distinctPartitionsColNames.size == 1, {
- val list = distinctPartitionsColNames.mkString("\t", "\n", "")
+ val list = distinctPartitionsColNames.mkString("\t", "\n\t", "")
s"Conflicting partition column names detected:\n$list"
})
@@ -205,25 +207,36 @@ private[sql] object PartitioningUtils {
}
/**
- * Converts a string to a `Literal` with automatic type inference. Currently only supports
- * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and
+ * Converts a string to a [[Literal]] with automatic type inference. Currently only supports
+ * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and
* [[StringType]].
*/
private[sql] def inferPartitionColumnValue(
raw: String,
- defaultPartitionName: String): Literal = {
- // First tries integral types
- Try(Literal.create(Integer.parseInt(raw), IntegerType))
- .orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
- // Then falls back to fractional types
- .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType)))
- .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
- .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited)))
- // Then falls back to string
- .getOrElse {
- if (raw == defaultPartitionName) Literal.create(null, NullType)
- else Literal.create(unescapePathName(raw), StringType)
+ defaultPartitionName: String,
+ typeInference: Boolean): Literal = {
+ if (typeInference) {
+ // First tries integral types
+ Try(Literal.create(Integer.parseInt(raw), IntegerType))
+ .orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
+ // Then falls back to fractional types
+ .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
+ .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited)))
+ // Then falls back to string
+ .getOrElse {
+ if (raw == defaultPartitionName) {
+ Literal.create(null, NullType)
+ } else {
+ Literal.create(unescapePathName(raw), StringType)
+ }
+ }
+ } else {
+ if (raw == defaultPartitionName) {
+ Literal.create(null, NullType)
+ } else {
+ Literal.create(unescapePathName(raw), StringType)
}
+ }
}
private val upCastingOrder: Seq[DataType] =
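A hedged, standalone sketch of the inference chain above, using plain Scala values instead of Catalyst Literals (names are illustrative): integral types are tried first, then fractional types, then the value falls back to a string, with the default partition name mapping to null and inference skipped entirely when the new typeInference flag is off. Path unescaping is omitted here.

import scala.util.Try

object PartitionValueSketch {
  def inferValue(raw: String, defaultPartitionName: String, typeInference: Boolean): Any = {
    if (raw == defaultPartitionName) {
      null
    } else if (!typeInference) {
      raw
    } else {
      Try(raw.toInt: Any)
        .orElse(Try(raw.toLong))
        .orElse(Try(raw.toDouble))
        .orElse(Try(BigDecimal(raw)))
        .getOrElse(raw)
    }
  }
}

// inferValue("10", "__HIVE_DEFAULT_PARTITION__", typeInference = true)         == 10
// inferValue("4000000000", "__HIVE_DEFAULT_PARTITION__", typeInference = true) == 4000000000L
// inferValue("1.5", "__HIVE_DEFAULT_PARTITION__", typeInference = true)        == 1.5
// inferValue("hello", "__HIVE_DEFAULT_PARTITION__", typeInference = false)     == "hello"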
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
index a74a98631da3..ebad0c1564ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
@@ -216,7 +216,7 @@ private[sql] class SqlNewHadoopRDD[K, V](
override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = {
val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value
val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
- case Some(c) =>
+ case Some(c) =>
try {
val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]]
Some(HadoopRDD.convertSplitLocationInfo(infos))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 3132067d562f..3dbe6faabf45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -23,19 +23,20 @@ import scala.collection.mutable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat}
-import parquet.hadoop.util.ContextUtil
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, FileOutputCommitter => MapReduceFileOutputCommitter}
+import org.apache.parquet.hadoop.util.ContextUtil
import org.apache.spark._
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode}
+import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode}
private[sql] case class InsertIntoDataSource(
logicalRelation: LogicalRelation,
@@ -43,18 +44,17 @@ private[sql] case class InsertIntoDataSource(
overwrite: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
val data = DataFrame(sqlContext, query)
// Apply the schema of the existing table to the new data.
- val df = sqlContext.createDataFrame(
- data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false)
+ val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite)
// Invalidate the cache.
sqlContext.cacheManager.invalidateCache(logicalRelation)
- Seq.empty[Row]
+ Seq.empty[InternalRow]
}
}
@@ -64,7 +64,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
mode: SaveMode)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
require(
relation.paths.length == 1,
s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}")
@@ -89,15 +89,22 @@ private[sql] case class InsertIntoHadoopFsRelation(
if (doInsertion) {
val job = new Job(hadoopConf)
job.setOutputKeyClass(classOf[Void])
- job.setOutputValueClass(classOf[Row])
+ job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, qualifiedOutputPath)
// We create a DataFrame by applying the schema of the relation to the data to make sure
// we are writing data based on the expected schema.
- val df = sqlContext.createDataFrame(
- DataFrame(sqlContext, query).queryExecution.toRdd,
- relation.schema,
- needsConversion = false)
+ val df = {
+ // For partitioned relation r, r.schema's column ordering can be different from the column
+ // ordering of data.logicalPlan (partition columns are all moved after the data columns). We
+ // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation we can
+ // safely apply r.schema to the data.
+ val project = Project(
+ relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query)
+
+ sqlContext.internalCreateDataFrame(
+ DataFrame(sqlContext, project).queryExecution.toRdd, relation.schema)
+ }
val partitionColumns = relation.partitionColumns.fieldNames
if (partitionColumns.isEmpty) {
@@ -109,7 +116,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
}
}
- Seq.empty[Row]
+ Seq.empty[InternalRow]
}
private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = {
@@ -117,8 +124,11 @@ private[sql] case class InsertIntoHadoopFsRelation(
val needsConversion = relation.needConversion
val dataSchema = relation.dataSchema
+ // This call shouldn't be put into the `try` block below because it only initializes and
+ // prepares the job; any exception thrown from here shouldn't cause abortJob() to be called.
+ writerContainer.driverSideSetup()
+
try {
- writerContainer.driverSideSetup()
df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
writerContainer.commitJob()
relation.refresh()
@@ -128,22 +138,21 @@ private[sql] case class InsertIntoHadoopFsRelation(
throw new SparkException("Job aborted.", cause)
}
- def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = {
- writerContainer.executorSideSetup(taskContext)
-
+ def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
+ // If anything below fails, we should abort the task.
try {
- if (needsConversion) {
- val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
- while (iterator.hasNext) {
- val row = converter(iterator.next()).asInstanceOf[Row]
- writerContainer.outputWriterForRow(row).write(row)
- }
+ writerContainer.executorSideSetup(taskContext)
+
+ val converter = if (needsConversion) {
+ CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row]
} else {
- while (iterator.hasNext) {
- val row = iterator.next()
- writerContainer.outputWriterForRow(row).write(row)
- }
+ r: InternalRow => r.asInstanceOf[Row]
}
+ while (iterator.hasNext) {
+ val row = converter(iterator.next())
+ writerContainer.outputWriterForRow(row).write(row)
+ }
+
writerContainer.commitTask()
} catch { case cause: Throwable =>
logError("Aborting task.", cause)
@@ -181,8 +190,11 @@ private[sql] case class InsertIntoHadoopFsRelation(
val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name))
val codegenEnabled = df.sqlContext.conf.codegenEnabled
+ // This call shouldn't be put into the `try` block below because it only initializes and
+ // prepares the job; any exception thrown from here shouldn't cause abortJob() to be called.
+ writerContainer.driverSideSetup()
+
try {
- writerContainer.driverSideSetup()
df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
writerContainer.commitJob()
relation.refresh()
@@ -192,33 +204,37 @@ private[sql] case class InsertIntoHadoopFsRelation(
throw new SparkException("Job aborted.", cause)
}
- def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = {
- writerContainer.executorSideSetup(taskContext)
+ def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
+ // If anything below fails, we should abort the task.
+ try {
+ writerContainer.executorSideSetup(taskContext)
- val partitionProj = newProjection(codegenEnabled, partitionOutput, output)
- val dataProj = newProjection(codegenEnabled, dataOutput, output)
+ val partitionProj = newProjection(codegenEnabled, partitionOutput, output)
+ val dataProj = newProjection(codegenEnabled, dataOutput, output)
- if (needsConversion) {
- val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
- while (iterator.hasNext) {
- val row = iterator.next()
- val partitionPart = partitionProj(row)
- val dataPart = dataProj(row)
- val convertedDataPart = converter(dataPart).asInstanceOf[Row]
- writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart)
+ val dataConverter: InternalRow => Row = if (needsConversion) {
+ CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row]
+ } else {
+ r: InternalRow => r.asInstanceOf[Row]
}
- } else {
val partitionSchema = StructType.fromAttributes(partitionOutput)
- val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema)
+ val partConverter: InternalRow => Row =
+ CatalystTypeConverters.createToScalaConverter(partitionSchema)
+ .asInstanceOf[InternalRow => Row]
+
while (iterator.hasNext) {
val row = iterator.next()
- val partitionPart = converter(partitionProj(row)).asInstanceOf[Row]
- val dataPart = dataProj(row)
+ val partitionPart = partConverter(partitionProj(row))
+ val dataPart = dataConverter(dataProj(row))
writerContainer.outputWriterForRow(partitionPart).write(dataPart)
}
- }
- writerContainer.commitTask()
+ writerContainer.commitTask()
+ } catch { case cause: Throwable =>
+ logError("Aborting task.", cause)
+ writerContainer.abortTask()
+ throw new SparkException("Task failed while writing rows.", cause)
+ }
}
}
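A hedged, simplified sketch of the write-loop shape introduced in the two writeRows refactorings above (stand-in functions, not Spark's types): the row converter is chosen once, identity when no conversion is needed, a single loop does the writing, and any failure aborts the task exactly once before rethrowing.

object WriteRowsSketch {
  def writeRows[Row](
      iterator: Iterator[Row],
      needsConversion: Boolean,
      toExternal: Row => Row,   // stands in for CatalystTypeConverters.createToScalaConverter
      write: Row => Unit,       // stands in for outputWriterForRow(row).write(row)
      setupTask: () => Unit,
      commitTask: () => Unit,
      abortTask: () => Unit): Unit = {
    // If anything below fails, abort the task.
    try {
      setupTask()
      val converter: Row => Row = if (needsConversion) toExternal else identity
      iterator.foreach(row => write(converter(row)))
      commitTask()
    } catch {
      case cause: Throwable =>
        abortTask()
        throw new RuntimeException("Task failed while writing rows.", cause)
    }
  }
}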
@@ -229,7 +245,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
inputSchema: Seq[Attribute]): Projection = {
log.debug(
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
- if (codegenEnabled) {
+ if (codegenEnabled && expressions.forall(_.isThreadSafe)) {
GenerateProjection.generate(expressions, inputSchema)
} else {
new InterpretedProjection(expressions, inputSchema)
@@ -272,8 +288,17 @@ private[sql] abstract class BaseWriterContainer(
def driverSideSetup(): Unit = {
setupIDs(0, 0, 0)
setupConf()
- taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
+
+ // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor
+ // clones the Configuration object passed in. If we initialize the TaskAttemptContext first,
+ // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext.
+ //
+ // Also, the `prepareJobForWrite` call must happen before initializing output format and output
+ // committer, since their initialization involves the job configuration, which can be potentially
+ // decorated in `prepareJobForWrite`.
outputWriterFactory = relation.prepareJobForWrite(job)
+ taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
+
outputFormatClass = job.getOutputFormatClass
outputCommitter = newOutputCommitter(taskAttemptContext)
outputCommitter.setupJob(jobContext)
@@ -301,6 +326,8 @@ private[sql] abstract class BaseWriterContainer(
SQLConf.OUTPUT_COMMITTER_CLASS, null, classOf[OutputCommitter])
Option(committerClass).map { clazz =>
+ logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}")
+
// Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat
// has an associated output committer. To override this output committer,
// we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS.
@@ -320,7 +347,9 @@ private[sql] abstract class BaseWriterContainer(
}.getOrElse {
// If output committer class is not set, we will use the one associated with the
// file output format.
- outputFormatClass.newInstance().getOutputCommitter(context)
+ val outputCommitter = outputFormatClass.newInstance().getOutputCommitter(context)
+ logInfo(s"Using output committer class ${outputCommitter.getClass.getCanonicalName}")
+ outputCommitter
}
}
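A hedged sketch of the committer-resolution order shown above, with the Hadoop and configuration types replaced by simple stand-ins: a user-configured committer class wins, otherwise the output format's own committer is used, and either way the choice is logged.

object CommitterResolutionSketch {
  def resolveCommitter[C](
      userConfiguredClass: Option[Class[_ <: C]],  // read from SQLConf.OUTPUT_COMMITTER_CLASS
      instantiate: Class[_ <: C] => C,             // e.g. via a TaskAttemptContext constructor
      defaultCommitter: () => C,                   // outputFormat.getOutputCommitter(context)
      logInfo: String => Unit): C = {
    userConfiguredClass.map { clazz =>
      logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}")
      instantiate(clazz)
    }.getOrElse {
      val committer = defaultCommitter()
      logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}")
      committer
    }
  }
}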
@@ -349,7 +378,9 @@ private[sql] abstract class BaseWriterContainer(
}
def abortTask(): Unit = {
- outputCommitter.abortTask(taskAttemptContext)
+ if (outputCommitter != null) {
+ outputCommitter.abortTask(taskAttemptContext)
+ }
logError(s"Task attempt $taskAttemptId aborted.")
}
@@ -359,7 +390,9 @@ private[sql] abstract class BaseWriterContainer(
}
def abortJob(): Unit = {
- outputCommitter.abortJob(jobContext, JobStatus.State.FAILED)
+ if (outputCommitter != null) {
+ outputCommitter.abortJob(jobContext, JobStatus.State.FAILED)
+ }
logError(s"Job $jobId aborted.")
}
}
@@ -380,6 +413,7 @@ private[sql] class DefaultWriterContainer(
override def commitTask(): Unit = {
try {
+ assert(writer != null, "OutputWriter instance should have been initialized")
writer.close()
super.commitTask()
} catch {
@@ -391,7 +425,9 @@ private[sql] class DefaultWriterContainer(
override def abortTask(): Unit = {
try {
- writer.close()
+ if (writer != null) {
+ writer.close()
+ }
} finally {
super.abortTask()
}
@@ -435,6 +471,7 @@ private[sql] class DynamicPartitionWriterContainer(
override def commitTask(): Unit = {
try {
outputWriters.values.foreach(_.close())
+ outputWriters.clear()
super.commitTask()
} catch { case cause: Throwable =>
super.abortTask()
@@ -445,6 +482,7 @@ private[sql] class DynamicPartitionWriterContainer(
override def abortTask(): Unit = {
try {
outputWriters.values.foreach(_.close())
+ outputWriters.clear()
} finally {
super.abortTask()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 22587f5a1c6f..b7095c8ead79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -25,8 +25,8 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InternalRow}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types._
@@ -166,10 +166,14 @@ private[sql] class DDLParser(
}
)
- protected lazy val optionName: Parser[String] = "[_a-zA-Z][a-zA-Z0-9]*".r ^^ {
+ protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ {
case name => name
}
+ protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ {
+ case parts => parts.mkString(".")
+ }
+
protected lazy val pair: Parser[(String, String)] =
optionName ~ stringLit ^^ { case k ~ v => (k, v) }
@@ -322,19 +326,13 @@ private[sql] object ResolvedDataSource {
Some(partitionColumnsSchema(data.schema, partitionColumns)),
caseInsensitiveOptions)
- // For partitioned relation r, r.schema's column ordering is different with the column
- // ordering of data.logicalPlan. We need a Project to adjust the ordering.
- // So, inside InsertIntoHadoopFsRelation, we can safely apply the schema of r.schema to
- // the data.
- val project =
- Project(
- r.schema.map(field => new UnresolvedAttribute(Seq(field.name))),
- data.logicalPlan)
-
+ // For partitioned relation r, r.schema's column ordering can be different from the column
+ // ordering of data.logicalPlan (partition columns are all moved after the data columns). This
+ // will be adjusted within InsertIntoHadoopFsRelation.
sqlContext.executePlan(
InsertIntoHadoopFsRelation(
r,
- project,
+ data.logicalPlan,
mode)).toRdd
r
case _ =>
@@ -410,7 +408,7 @@ private[sql] case class CreateTempTableUsing(
provider: String,
options: Map[String, String]) extends RunnableCommand {
- def run(sqlContext: SQLContext): Seq[Row] = {
+ def run(sqlContext: SQLContext): Seq[InternalRow] = {
val resolved = ResolvedDataSource(
sqlContext, userSpecifiedSchema, Array.empty[String], provider, options)
sqlContext.registerDataFrameAsTable(
@@ -427,7 +425,7 @@ private[sql] case class CreateTempTableUsingAsSelect(
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
val df = DataFrame(sqlContext, query)
val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df)
sqlContext.registerDataFrameAsTable(
@@ -440,7 +438,7 @@ private[sql] case class CreateTempTableUsingAsSelect(
private[sql] case class RefreshTable(databaseName: String, tableName: String)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
// Refresh the given table's metadata first.
sqlContext.catalog.refreshTable(databaseName, tableName)
@@ -459,7 +457,7 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String)
sqlContext.cacheManager.cacheQuery(df, Some(tableName))
}
- Seq.empty[Row]
+ Seq.empty[InternalRow]
}
}
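A hedged, standalone sketch of the dotted option-name grammar introduced above, written against the scala-parser-combinators API outside of Spark's DDLParser: an option name is one or more identifier parts joined by dots, so keys such as spark.sql.parquet.output.committer.class parse as a single OPTIONS key.

import scala.util.parsing.combinator.RegexParsers

object OptionNameSketch extends RegexParsers {
  val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r
  val optionName: Parser[String] = repsep(optionPart, ".") ^^ (_.mkString("."))

  def parseName(input: String): Option[String] = parseAll(optionName, input) match {
    case Success(name, _) => Some(name)
    case _ => None
  }
}

// OptionNameSketch.parseName("spark.sql.parquet.output.committer.class")
//   == Some("spark.sql.parquet.output.committer.class")
// OptionNameSketch.parseName("1bad.name") == None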
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index c06026e042d9..43d3507d7d2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -28,7 +28,8 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.SerializableWritable
-import org.apache.spark.sql.{Row, _}
+import org.apache.spark.sql.execution.RDDConversions
+import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.types.StructType
@@ -93,7 +94,7 @@ trait SchemaRelationProvider {
}
/**
- * ::DeveloperApi::
+ * ::Experimental::
* Implemented by objects that produce relations for a specific kind of data source
* with a given schema and partitioned columns. When Spark SQL is given a DDL operation with a
* USING clause specified (to specify the implemented [[HadoopFsRelationProvider]]), a user defined
@@ -115,6 +116,7 @@ trait SchemaRelationProvider {
*
* @since 1.4.0
*/
+@Experimental
trait HadoopFsRelationProvider {
/**
* Returns a new base relation with the given parameters, a user defined schema, and a list of
@@ -194,6 +196,8 @@ abstract class BaseRelation {
* java.lang.String -> UTF8String
* java.lang.Decimal -> Decimal
*
+ * If `needConversion` is `false`, buildScan() should return an [[RDD]] of [[InternalRow]]
+ *
* Note: The internal representation is not stable across releases and thus data sources outside
* of Spark SQL should leave this as true.
*
@@ -378,10 +382,10 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]]
def refresh(): Unit = {
- // We don't filter files/directories whose name start with "_" or "." here, as specific data
- // sources may take advantages over them (e.g. Parquet _metadata and _common_metadata files).
- // But "_temporary" directories are explicitly ignored since failed tasks/jobs may leave
- // partial/corrupted data files there.
+ // We don't filter out files/directories whose names start with "_", except "_temporary", as
+ // specific data sources may take advantage of them (e.g. Parquet _metadata and
+ // _common_metadata files). "_temporary" directories are explicitly ignored since failed
+ // tasks/jobs may leave partial/corrupted data files there.
def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = {
if (status.getPath.getName.toLowerCase == "_temporary") {
Set.empty
@@ -399,6 +403,9 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
val fs = hdfsPath.getFileSystem(hadoopConf)
val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _))
+ }.filterNot { status =>
+ // SPARK-8037: Ignores files like ".DS_Store" and other hidden files/directories
+ status.getPath.getName.startsWith(".")
}
val files = statuses.filterNot(_.isDir)
@@ -431,14 +438,15 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
// partition values.
userDefinedPartitionColumns.map { partitionSchema =>
val spec = discoverPartitions()
+ val partitionColumnTypes = spec.partitionColumns.map(_.dataType)
val castedPartitions = spec.partitions.map { case p @ Partition(values, path) =>
- val literals = values.toSeq.zip(spec.partitionColumns.map(_.dataType)).map {
+ val literals = values.toSeq.zip(partitionColumnTypes).map {
case (value, dataType) => Literal.create(value, dataType)
}
val castedValues = partitionSchema.zip(literals).map { case (field, literal) =>
Cast(literal, field.dataType).eval()
}
- p.copy(values = Row.fromSeq(castedValues))
+ p.copy(values = InternalRow.fromSeq(castedValues))
}
PartitionSpec(partitionSchema, castedPartitions)
}
@@ -486,9 +494,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
}
private def discoverPartitions(): PartitionSpec = {
+ val typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled()
// We use leaf dirs containing data files to discover the schema.
val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq
- PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME)
+ PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME,
+ typeInference)
}
/**
@@ -499,7 +509,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
*/
override lazy val schema: StructType = {
val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet
- StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column =>
+ StructType(dataSchema ++ partitionColumns.filterNot { column =>
dataSchemaColumnNames.contains(column.name.toLowerCase)
})
}
@@ -566,21 +576,28 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
// Yeah, to work around serialization...
val dataSchema = this.dataSchema
val codegenEnabled = this.codegenEnabled
+ val needConversion = this.needConversion
val requiredOutput = requiredColumns.map { col =>
val field = dataSchema(col)
BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable)
}.toSeq
- buildScan(inputFiles).mapPartitions { rows =>
- val buildProjection = if (codegenEnabled) {
+ val rdd = buildScan(inputFiles)
+ val converted =
+ if (needConversion) {
+ RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType))
+ } else {
+ rdd.map(_.asInstanceOf[InternalRow])
+ }
+ converted.mapPartitions { rows =>
+ val buildProjection = if (codegenEnabled && requiredOutput.forall(_.isThreadSafe)) {
GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes)
} else {
() => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes)
}
-
val mutableProjection = buildProjection()
- rows.map(mutableProjection)
+ rows.map(r => mutableProjection(r).asInstanceOf[Row])
}
}
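A hedged sketch of the listing rules described in the refresh() comment above, written against java.io.File instead of Hadoop's FileSystem API: names starting with "_" are kept (Parquet's _metadata and _common_metadata rely on this), while "_temporary" directories and hidden "." entries such as .DS_Store are skipped, as in SPARK-8037.

import java.io.File

object LeafFileListingSketch {
  def listLeafFiles(root: File): Seq[File] = {
    def walk(f: File): Seq[File] = {
      val name = f.getName
      if (name.toLowerCase == "_temporary" || name.startsWith(".")) {
        Seq.empty                                           // skip temp dirs and hidden entries
      } else if (f.isDirectory) {
        Option(f.listFiles()).toSeq.flatten.flatMap(walk)   // guard against null from listFiles()
      } else {
        Seq(f)                                              // keep data files and "_" metadata files
      }
    }
    walk(root)
  }
}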
diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties
index 28e90b9520b2..12fb128149d3 100644
--- a/sql/core/src/test/resources/log4j.properties
+++ b/sql/core/src/test/resources/log4j.properties
@@ -36,11 +36,11 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n
log4j.appender.FA.Threshold = INFO
# Some packages are noisy for no good reason.
-log4j.additivity.parquet.hadoop.ParquetRecordReader=false
-log4j.logger.parquet.hadoop.ParquetRecordReader=OFF
+log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false
+log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF
-log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false
-log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF
+log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false
+log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF
log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false
log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF
@@ -52,5 +52,5 @@ log4j.additivity.hive.ql.metadata.Hive=false
log4j.logger.hive.ql.metadata.Hive=OFF
# Parquet related logging
-log4j.logger.parquet.hadoop=WARN
+log4j.logger.org.apache.parquet.hadoop=WARN
log4j.logger.org.apache.spark.sql.parquet=INFO
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 0772e5e18742..eb3e91332206 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -25,17 +25,19 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.Accumulators
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+import org.apache.spark.storage.{StorageLevel, RDDBlockId}
case class BigData(s: String)
class CachedTableSuite extends QueryTest {
TestData // Load test tables.
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+ import ctx.sql
+
def rddIdOf(tableName: String): Int = {
- val executedPlan = table(tableName).queryExecution.executedPlan
+ val executedPlan = ctx.table(tableName).queryExecution.executedPlan
executedPlan.collect {
case InMemoryColumnarTableScan(_, _, relation) =>
relation.cachedColumnBuffers.id
@@ -45,151 +47,151 @@ class CachedTableSuite extends QueryTest {
}
def isMaterialized(rddId: Int): Boolean = {
- sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty
+ ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty
}
test("cache temp table") {
testData.select('key).registerTempTable("tempTable")
assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0)
- cacheTable("tempTable")
+ ctx.cacheTable("tempTable")
assertCached(sql("SELECT COUNT(*) FROM tempTable"))
- uncacheTable("tempTable")
+ ctx.uncacheTable("tempTable")
}
test("unpersist an uncached table will not raise exception") {
- assert(None == cacheManager.lookupCachedData(testData))
+ assert(None == ctx.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = true)
- assert(None == cacheManager.lookupCachedData(testData))
+ assert(None == ctx.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = false)
- assert(None == cacheManager.lookupCachedData(testData))
+ assert(None == ctx.cacheManager.lookupCachedData(testData))
testData.persist()
- assert(None != cacheManager.lookupCachedData(testData))
+ assert(None != ctx.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = true)
- assert(None == cacheManager.lookupCachedData(testData))
+ assert(None == ctx.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = false)
- assert(None == cacheManager.lookupCachedData(testData))
+ assert(None == ctx.cacheManager.lookupCachedData(testData))
}
test("cache table as select") {
sql("CACHE TABLE tempTable AS SELECT key FROM testData")
assertCached(sql("SELECT COUNT(*) FROM tempTable"))
- uncacheTable("tempTable")
+ ctx.uncacheTable("tempTable")
}
test("uncaching temp table") {
testData.select('key).registerTempTable("tempTable1")
testData.select('key).registerTempTable("tempTable2")
- cacheTable("tempTable1")
+ ctx.cacheTable("tempTable1")
assertCached(sql("SELECT COUNT(*) FROM tempTable1"))
assertCached(sql("SELECT COUNT(*) FROM tempTable2"))
// Is this valid?
- uncacheTable("tempTable2")
+ ctx.uncacheTable("tempTable2")
// Should this be cached?
assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0)
}
test("too big for memory") {
- val data = "*" * 10000
- sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF()
+ val data = "*" * 1000
+ ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF()
.registerTempTable("bigData")
- table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
- assert(table("bigData").count() === 200000L)
- table("bigData").unpersist(blocking = true)
+ ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
+ assert(ctx.table("bigData").count() === 200000L)
+ ctx.table("bigData").unpersist(blocking = true)
}
test("calling .cache() should use in-memory columnar caching") {
- table("testData").cache()
- assertCached(table("testData"))
- table("testData").unpersist(blocking = true)
+ ctx.table("testData").cache()
+ assertCached(ctx.table("testData"))
+ ctx.table("testData").unpersist(blocking = true)
}
test("calling .unpersist() should drop in-memory columnar cache") {
- table("testData").cache()
- table("testData").count()
- table("testData").unpersist(blocking = true)
- assertCached(table("testData"), 0)
+ ctx.table("testData").cache()
+ ctx.table("testData").count()
+ ctx.table("testData").unpersist(blocking = true)
+ assertCached(ctx.table("testData"), 0)
}
test("isCached") {
- cacheTable("testData")
+ ctx.cacheTable("testData")
- assertCached(table("testData"))
- assert(table("testData").queryExecution.withCachedData match {
+ assertCached(ctx.table("testData"))
+ assert(ctx.table("testData").queryExecution.withCachedData match {
case _: InMemoryRelation => true
case _ => false
})
- uncacheTable("testData")
- assert(!isCached("testData"))
- assert(table("testData").queryExecution.withCachedData match {
+ ctx.uncacheTable("testData")
+ assert(!ctx.isCached("testData"))
+ assert(ctx.table("testData").queryExecution.withCachedData match {
case _: InMemoryRelation => false
case _ => true
})
}
test("SPARK-1669: cacheTable should be idempotent") {
- assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
+ assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
- cacheTable("testData")
- assertCached(table("testData"))
+ ctx.cacheTable("testData")
+ assertCached(ctx.table("testData"))
assertResult(1, "InMemoryRelation not found, testData should have been cached") {
- table("testData").queryExecution.withCachedData.collect {
+ ctx.table("testData").queryExecution.withCachedData.collect {
case r: InMemoryRelation => r
}.size
}
- cacheTable("testData")
+ ctx.cacheTable("testData")
assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") {
- table("testData").queryExecution.withCachedData.collect {
+ ctx.table("testData").queryExecution.withCachedData.collect {
case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r
}.size
}
- uncacheTable("testData")
+ ctx.uncacheTable("testData")
}
test("read from cached table and uncache") {
- cacheTable("testData")
- checkAnswer(table("testData"), testData.collect().toSeq)
- assertCached(table("testData"))
+ ctx.cacheTable("testData")
+ checkAnswer(ctx.table("testData"), testData.collect().toSeq)
+ assertCached(ctx.table("testData"))
- uncacheTable("testData")
- checkAnswer(table("testData"), testData.collect().toSeq)
- assertCached(table("testData"), 0)
+ ctx.uncacheTable("testData")
+ checkAnswer(ctx.table("testData"), testData.collect().toSeq)
+ assertCached(ctx.table("testData"), 0)
}
test("correct error on uncache of non-cached table") {
intercept[IllegalArgumentException] {
- uncacheTable("testData")
+ ctx.uncacheTable("testData")
}
}
test("SELECT star from cached table") {
sql("SELECT * FROM testData").registerTempTable("selectStar")
- cacheTable("selectStar")
+ ctx.cacheTable("selectStar")
checkAnswer(
sql("SELECT * FROM selectStar WHERE key = 1"),
Seq(Row(1, "1")))
- uncacheTable("selectStar")
+ ctx.uncacheTable("selectStar")
}
test("Self-join cached") {
val unCachedAnswer =
sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
- cacheTable("testData")
+ ctx.cacheTable("testData")
checkAnswer(
sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
unCachedAnswer.toSeq)
- uncacheTable("testData")
+ ctx.uncacheTable("testData")
}
test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") {
sql("CACHE TABLE testData")
- assertCached(table("testData"))
+ assertCached(ctx.table("testData"))
val rddId = rddIdOf("testData")
assert(
@@ -197,7 +199,7 @@ class CachedTableSuite extends QueryTest {
"Eagerly cached in-memory table should have already been materialized")
sql("UNCACHE TABLE testData")
- assert(!isCached("testData"), "Table 'testData' should not be cached")
+ assert(!ctx.isCached("testData"), "Table 'testData' should not be cached")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
@@ -206,14 +208,14 @@ class CachedTableSuite extends QueryTest {
test("CACHE TABLE tableName AS SELECT * FROM anotherTable") {
sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
- assertCached(table("testCacheTable"))
+ assertCached(ctx.table("testCacheTable"))
val rddId = rddIdOf("testCacheTable")
assert(
isMaterialized(rddId),
"Eagerly cached in-memory table should have already been materialized")
- uncacheTable("testCacheTable")
+ ctx.uncacheTable("testCacheTable")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
@@ -221,14 +223,14 @@ class CachedTableSuite extends QueryTest {
test("CACHE TABLE tableName AS SELECT ...") {
sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10")
- assertCached(table("testCacheTable"))
+ assertCached(ctx.table("testCacheTable"))
val rddId = rddIdOf("testCacheTable")
assert(
isMaterialized(rddId),
"Eagerly cached in-memory table should have already been materialized")
- uncacheTable("testCacheTable")
+ ctx.uncacheTable("testCacheTable")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
@@ -236,7 +238,7 @@ class CachedTableSuite extends QueryTest {
test("CACHE LAZY TABLE tableName") {
sql("CACHE LAZY TABLE testData")
- assertCached(table("testData"))
+ assertCached(ctx.table("testData"))
val rddId = rddIdOf("testData")
assert(
@@ -248,7 +250,7 @@ class CachedTableSuite extends QueryTest {
isMaterialized(rddId),
"Lazily cached in-memory table should have been materialized")
- uncacheTable("testData")
+ ctx.uncacheTable("testData")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
@@ -256,7 +258,7 @@ class CachedTableSuite extends QueryTest {
test("InMemoryRelation statistics") {
sql("CACHE TABLE testData")
- table("testData").queryExecution.withCachedData.collect {
+ ctx.table("testData").queryExecution.withCachedData.collect {
case cached: InMemoryRelation =>
val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum
assert(cached.statistics.sizeInBytes === actualSizeInBytes)
@@ -265,38 +267,38 @@ class CachedTableSuite extends QueryTest {
test("Drops temporary table") {
testData.select('key).registerTempTable("t1")
- table("t1")
- dropTempTable("t1")
- assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found"))
+ ctx.table("t1")
+ ctx.dropTempTable("t1")
+ assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found"))
}
test("Drops cached temporary table") {
testData.select('key).registerTempTable("t1")
testData.select('key).registerTempTable("t2")
- cacheTable("t1")
+ ctx.cacheTable("t1")
- assert(isCached("t1"))
- assert(isCached("t2"))
+ assert(ctx.isCached("t1"))
+ assert(ctx.isCached("t2"))
- dropTempTable("t1")
- assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found"))
- assert(!isCached("t2"))
+ ctx.dropTempTable("t1")
+ assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found"))
+ assert(!ctx.isCached("t2"))
}
test("Clear all cache") {
sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
- cacheTable("t1")
- cacheTable("t2")
- clearCache()
- assert(cacheManager.isEmpty)
+ ctx.cacheTable("t1")
+ ctx.cacheTable("t2")
+ ctx.clearCache()
+ assert(ctx.cacheManager.isEmpty)
sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
- cacheTable("t1")
- cacheTable("t2")
+ ctx.cacheTable("t1")
+ ctx.cacheTable("t2")
sql("Clear CACHE")
- assert(cacheManager.isEmpty)
+ assert(ctx.cacheManager.isEmpty)
}
test("Clear accumulators when uncacheTable to prevent memory leaking") {
@@ -305,8 +307,8 @@ class CachedTableSuite extends QueryTest {
Accumulators.synchronized {
val accsSize = Accumulators.originals.size
- cacheTable("t1")
- cacheTable("t2")
+ ctx.cacheTable("t1")
+ ctx.cacheTable("t2")
assert((accsSize + 2) == Accumulators.originals.size)
}
@@ -317,8 +319,8 @@ class CachedTableSuite extends QueryTest {
Accumulators.synchronized {
val accsSize = Accumulators.originals.size
- uncacheTable("t1")
- uncacheTable("t2")
+ ctx.uncacheTable("t1")
+ ctx.uncacheTable("t2")
assert((accsSize - 2) == Accumulators.originals.size)
}
}
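
Across the suites touched by this patch the change follows one pattern: the wildcard imports of TestSQLContext members are replaced with an explicit context value, and every context call is qualified through it. A minimal sketch of the resulting shape, with a hypothetical suite name, table name, and data that are not taken from the patch itself:

import org.apache.spark.sql.QueryTest

class ExampleCachingSuite extends QueryTest {

  // Hold the shared test context explicitly instead of importing its members wildcard-style.
  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
  import ctx.implicits._

  test("cache and uncache a temporary table") {
    // toDF comes from ctx.implicits._; every context method is qualified with ctx.
    Seq((1, "a"), (2, "b")).toDF("key", "value").registerTempTable("exampleTable")
    ctx.cacheTable("exampleTable")
    assert(ctx.isCached("exampleTable"))
    ctx.uncacheTable("exampleTable")
    assert(!ctx.isCached("exampleTable"))
    ctx.dropTempTable("exampleTable")
  }
}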
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index d006b83fc075..5a08578e7ba4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -19,14 +19,22 @@ package org.apache.spark.sql
import org.scalatest.Matchers._
+import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
class ColumnExpressionSuite extends QueryTest {
import org.apache.spark.sql.TestData._
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
+ test("alias") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ assert(df.select(df("a").as("b")).columns.head === "b")
+ assert(df.select(df("a").alias("b")).columns.head === "b")
+ }
+
test("single explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
@@ -177,12 +185,20 @@ class ColumnExpressionSuite extends QueryTest {
checkAnswer(
nullStrings.toDF.where($"s".isNull),
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
+
+ checkAnswer(
+ ctx.sql("select isnull(null), isnull(1)"),
+ Row(true, false))
}
test("isNotNull") {
checkAnswer(
nullStrings.toDF.where($"s".isNotNull),
nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
+
+ checkAnswer(
+ ctx.sql("select isnotnull(null), isnotnull('a')"),
+ Row(false, true))
}
test("===") {
@@ -206,7 +222,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("!==") {
- val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
+ val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize(
Row(1, 1) ::
Row(1, 2) ::
Row(1, null) ::
@@ -267,7 +283,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("between") {
- val testData = TestSQLContext.sparkContext.parallelize(
+ val testData = ctx.sparkContext.parallelize(
(0, 1, 2) ::
(1, 2, 3) ::
(2, 1, 0) ::
@@ -280,7 +296,7 @@ class ColumnExpressionSuite extends QueryTest {
checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer)
}
- val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
+ val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
@@ -353,23 +369,6 @@ class ColumnExpressionSuite extends QueryTest {
)
}
- test("abs") {
- checkAnswer(
- testData.select(abs('key)).orderBy('key.asc),
- (1 to 100).map(n => Row(n))
- )
-
- checkAnswer(
- negativeData.select(abs('key)).orderBy('key.desc),
- (1 to 100).map(n => Row(n))
- )
-
- checkAnswer(
- testData.select(abs(lit(null))),
- (1 to 100).map(_ => Row(null))
- )
- }
-
test("upper") {
checkAnswer(
lowerCaseData.select(upper('l)),
@@ -385,6 +384,10 @@ class ColumnExpressionSuite extends QueryTest {
testData.select(upper(lit(null))),
(1 to 100).map(n => Row(null))
)
+
+ checkAnswer(
+ ctx.sql("SELECT upper('aB'), ucase('cDe')"),
+ Row("AB", "CDE"))
}
test("lower") {
@@ -402,11 +405,15 @@ class ColumnExpressionSuite extends QueryTest {
testData.select(lower(lit(null))),
(1 to 100).map(n => Row(null))
)
+
+ checkAnswer(
+ ctx.sql("SELECT lower('aB'), lcase('cDe')"),
+ Row("ab", "cde"))
}
test("monotonicallyIncreasingId") {
// Make sure we have 2 partitions, each with 2 records.
- val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter =>
+ val df = ctx.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter =>
Iterator(Tuple1(1), Tuple1(2))
}.toDF("a")
checkAnswer(
@@ -416,7 +423,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("sparkPartitionId") {
- val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b")
+ val df = ctx.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b")
checkAnswer(
df.select(sparkPartitionId()),
Row(0)
@@ -446,13 +453,51 @@ class ColumnExpressionSuite extends QueryTest {
}
test("rand") {
- val randCol = testData.select('key, rand(5L).as("rand"))
+ val randCol = testData.select($"key", rand(5L).as("rand"))
randCol.columns.length should be (2)
val rows = randCol.collect()
rows.foreach { row =>
assert(row.getDouble(1) <= 1.0)
assert(row.getDouble(1) >= 0.0)
}
+
+ def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = {
+ val projects = df.queryExecution.executedPlan.collect {
+ case project: Project => project
+ }
+ assert(projects.size === expectedNumProjects)
+ }
+
+ // We first create a plan with two Projects.
+ // Project [rand + 1 AS rand1, rand - 1 AS rand2]
+ // Project [key, (Rand 5 + 1) AS rand]
+ // LogicalRDD [key, value]
+ // Because the Rand function is not deterministic, the column rand is not deterministic.
+ // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2]
+ // and Project [key, (Rand 5 + 1) AS rand]. The final plan still has two Projects.
+ val dfWithTwoProjects =
+ testData
+ .select($"key", (rand(5L) + 1).as("rand"))
+ .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2"))
+ checkNumProjects(dfWithTwoProjects, 2)
+
+ // Now, we add one more Project computing rand1 - rand2 on top of the query plan.
+ // Since rand1 - rand2 is deterministic given rand1 and rand2 (it only applies +/- to the
+ // generated rand value), it can be collapsed into the Project generating rand1 and rand2.
+ // So, the plan will be optimized from ...
+ // Project [(rand1 - rand2) AS (rand1 - rand2)]
+ // Project [rand + 1 AS rand1, rand - 1 AS rand2]
+ // Project [key, (Rand 5 + 1) AS rand]
+ // LogicalRDD [key, value]
+ // to ...
+ // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)]
+ // Project [key, (Rand 5 + 1) AS rand]
+ // LogicalRDD [key, value]
+ val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2")
+ checkNumProjects(dfWithThreeProjects, 2)
+ dfWithThreeProjects.collect().foreach { row =>
+ assert(row.getDouble(0) === 2.0 +- 0.0001)
+ }
}
test("randn") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 232f05c00918..790b405c7269 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.DecimalType
class DataFrameAggregateSuite extends QueryTest {
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
test("groupBy") {
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
@@ -67,12 +68,12 @@ class DataFrameAggregateSuite extends QueryTest {
Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)
- TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false")
+ ctx.conf.setConf("spark.sql.retainGroupColumns", "false")
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Seq(Row(3), Row(3), Row(3))
)
- TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true")
+ ctx.conf.setConf("spark.sql.retainGroupColumns", "true")
}
test("agg without groups") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index b1e0faa310b6..cfd23867a9bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
/**
@@ -27,6 +26,9 @@ import org.apache.spark.sql.types._
*/
class DataFrameFunctionsSuite extends QueryTest {
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
test("array with column name") {
val df = Seq((0, 1)).toDF("a", "b")
val row = df.select(array("a", "b")).first()
@@ -83,9 +85,59 @@ class DataFrameFunctionsSuite extends QueryTest {
}
}
+ test("constant functions") {
+ checkAnswer(
+ testData2.select(e()).limit(1),
+ Row(scala.math.E)
+ )
+ checkAnswer(
+ testData2.select(pi()).limit(1),
+ Row(scala.math.Pi)
+ )
+ checkAnswer(
+ ctx.sql("SELECT E()"),
+ Row(scala.math.E)
+ )
+ checkAnswer(
+ ctx.sql("SELECT PI()"),
+ Row(scala.math.Pi)
+ )
+ }
+
test("bitwiseNOT") {
checkAnswer(
testData2.select(bitwiseNOT($"a")),
testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
}
+
+ test("if function") {
+ val df = Seq((1, 2)).toDF("a", "b")
+ checkAnswer(
+ df.selectExpr("if(a = 1, 'one', 'not_one')", "if(b = 1, 'one', 'not_one')"),
+ Row("one", "not_one"))
+ }
+
+ test("nvl function") {
+ checkAnswer(
+ ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
+ Row("x", "y", null))
+ }
+
+ test("string length function") {
+ checkAnswer(
+ nullStrings.select(strlen($"s"), strlen("s")),
+ nullStrings.collect().toSeq.map { r =>
+ val v = r.getString(1)
+ val l = if (v == null) null else v.length
+ Row(l, l)
+ })
+
+ checkAnswer(
+ nullStrings.selectExpr("length(s)"),
+ nullStrings.collect().toSeq.map { r =>
+ val v = r.getString(1)
+ val l = if (v == null) null else v.length
+ Row(l)
+ })
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
index 2d2367d6e729..fbb30706a494 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.sql
-import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc}
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-
-
class DataFrameImplicitsSuite extends QueryTest {
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
test("RDD of tuples") {
checkAnswer(
- sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
+ ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
(1 to 10).map(i => Row(i, i.toString)))
}
@@ -37,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest {
test("RDD[Int]") {
checkAnswer(
- sc.parallelize(1 to 10).toDF("intCol"),
+ ctx.sparkContext.parallelize(1 to 10).toDF("intCol"),
(1 to 10).map(i => Row(i)))
}
test("RDD[Long]") {
checkAnswer(
- sc.parallelize(1L to 10L).toDF("longCol"),
+ ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"),
(1L to 10L).map(i => Row(i)))
}
test("RDD[String]") {
checkAnswer(
- sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
+ ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
(1 to 10).map(i => Row(i.toString)))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 787f3f175fea..6165764632c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-
class DataFrameJoinSuite extends QueryTest {
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
test("join - join using") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str")
@@ -34,6 +34,15 @@ class DataFrameJoinSuite extends QueryTest {
Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil)
}
+ test("join - join using multiple columns") {
+ val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str")
+ val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str")
+
+ checkAnswer(
+ df.join(df2, Seq("int", "int2")),
+ Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil)
+ }
+
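
The new multi-column using-join test above relies on the Seq-based join overload deduplicating the join keys. A usage sketch under that assumption, with illustrative data:

object UsingJoinSketch {
  private val ctx = org.apache.spark.sql.test.TestSQLContext
  import ctx.implicits._

  def main(args: Array[String]): Unit = {
    val left = Seq((1, 2, "x"), (2, 3, "y")).toDF("int", "int2", "str")
    val right = Seq((1, 2, "p"), (2, 3, "q")).toDF("int", "int2", "str")

    // Joining on a Seq of column names keeps a single copy of "int" and "int2",
    // so the output schema is (int, int2, str, str) rather than repeating the keys.
    left.join(right, Seq("int", "int2")).show()
  }
}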
test("join - join using self join") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
@@ -49,7 +58,8 @@ class DataFrameJoinSuite extends QueryTest {
checkAnswer(
df1.join(df2, $"df1.key" === $"df2.key"),
- sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
+ ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
+ .collect().toSeq)
}
test("join - using aliases after self join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 41b4f02e6a29..495701d4f616 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql
import scala.collection.JavaConversions._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-
class DataFrameNaFunctionsSuite extends QueryTest {
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
def createDF(): DataFrame = {
Seq[(String, java.lang.Integer, java.lang.Double)](
("Bob", 16, 176.5),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 46b1845a9180..0d3ff899dad7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -17,16 +17,16 @@
package org.apache.spark.sql
-import org.scalatest.FunSuite
import org.scalatest.Matchers._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.SparkFunSuite
-class DataFrameStatSuite extends FunSuite {
-
- val sqlCtx = TestSQLContext
- def toLetter(i: Int): String = (i + 97).toChar.toString
+class DataFrameStatSuite extends SparkFunSuite {
+
+ private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
+ import sqlCtx.implicits._
+
+ private def toLetter(i: Int): String = (i + 97).toChar.toString
test("pearson correlation") {
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
@@ -74,10 +74,10 @@ class DataFrameStatSuite extends FunSuite {
val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
assert(rows(0).get(0).toString === "0")
assert(rows(0).getLong(1) === 2L)
- assert(rows(0).get(2) === null)
+ assert(rows(0).get(2) === 0L)
assert(rows(1).get(0).toString === "1")
assert(rows(1).getLong(1) === 1L)
- assert(rows(1).get(2) === null)
+ assert(rows(1).get(2) === 0L)
assert(rows(2).get(0).toString === "2")
assert(rows(2).getLong(1) === 2L)
assert(rows(2).getLong(2) === 1L)
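
The crosstab assertions above change from null to 0L: after this patch, value pairs that never co-occur are reported as a zero count instead of a null cell. A small usage sketch, with data chosen only for illustration:

object CrosstabSketch {
  private val ctx = org.apache.spark.sql.test.TestSQLContext
  import ctx.implicits._

  def main(args: Array[String]): Unit = {
    val df = Seq((0, 0), (0, 0), (1, 1), (2, 2)).toDF("a", "b")
    // Cells for (a, b) combinations that never occur together, such as a = 0 with b = 2,
    // now show a count of 0 instead of null.
    df.stat.crosstab("a", "b").show()
  }
}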
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a4fd1058afce..fa98e23e3d14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -21,17 +21,19 @@ import scala.language.postfixOps
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
-import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint}
class DataFrameSuite extends QueryTest {
import org.apache.spark.sql.TestData._
+ lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
test("analysis error should be eagerly reported") {
- val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis
+ val oldSetting = ctx.conf.dataFrameEagerAnalysis
// Eager analysis.
- TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true")
+ ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true")
intercept[Exception] { testData.select('nonExistentName) }
intercept[Exception] {
@@ -45,11 +47,11 @@ class DataFrameSuite extends QueryTest {
}
// No more eager analysis once the flag is turned off
- TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false")
+ ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false")
testData.select('nonExistentName)
// Set the flag back to original value before this test.
- TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
+ ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
}
test("dataframe toString") {
@@ -67,12 +69,12 @@ class DataFrameSuite extends QueryTest {
}
test("invalid plan toString, debug mode") {
- val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis
- TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true")
+ val oldSetting = ctx.conf.dataFrameEagerAnalysis
+ ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true")
// Turn on debug mode so we can see invalid query plans.
import org.apache.spark.sql.execution.debug._
- TestSQLContext.debug()
+ ctx.debug()
val badPlan = testData.select('badColumn)
@@ -81,7 +83,7 @@ class DataFrameSuite extends QueryTest {
badPlan.toString)
// Set the flag back to original value before this test.
- TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
+ ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
}
test("access complex data") {
@@ -97,8 +99,8 @@ class DataFrameSuite extends QueryTest {
}
test("empty data frame") {
- assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
- assert(TestSQLContext.emptyDataFrame.count() === 0)
+ assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String])
+ assert(ctx.emptyDataFrame.count() === 0)
}
test("head and take") {
@@ -132,6 +134,14 @@ class DataFrameSuite extends QueryTest {
)
}
+ test("explode alias and star") {
+ val df = Seq((Array("a"), 1)).toDF("a", "b")
+
+ checkAnswer(
+ df.select(explode($"a").as("a"), $"*"),
+ Row("a", Seq("a"), 1) :: Nil)
+ }
+
test("selectExpr") {
checkAnswer(
testData.selectExpr("abs(key)", "value"),
@@ -311,7 +321,7 @@ class DataFrameSuite extends QueryTest {
}
test("replace column using withColumn") {
- val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
+ val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
val df3 = df2.withColumn("x", df2("x") + 1)
checkAnswer(
df3.select("x"),
@@ -334,6 +344,51 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name) === Seq("key", "value"))
}
+ test("drop column using drop with column reference") {
+ val col = testData("key")
+ val df = testData.drop(col)
+ checkAnswer(
+ df,
+ testData.collect().map(x => Row(x.getString(1))).toSeq)
+ assert(df.schema.map(_.name) === Seq("value"))
+ }
+
+ test("drop unknown column (no-op) with column reference") {
+ val col = Column("random")
+ val df = testData.drop(col)
+ checkAnswer(
+ df,
+ testData.collect().toSeq)
+ assert(df.schema.map(_.name) === Seq("key", "value"))
+ }
+
+ test("drop unknown column with same name (no-op) with column reference") {
+ val col = Column("key")
+ val df = testData.drop(col)
+ checkAnswer(
+ df,
+ testData.collect().toSeq)
+ assert(df.schema.map(_.name) === Seq("key", "value"))
+ }
+
+ test("drop column after join with duplicate columns using column reference") {
+ val newSalary = salary.withColumnRenamed("personId", "id")
+ val col = newSalary("id")
+ // this join will result in duplicate "id" columns
+ val joinedDf = person.join(newSalary,
+ person("id") === newSalary("id"), "inner")
+ // remove only the "id" column that was associated with newSalary
+ val df = joinedDf.drop(col)
+ checkAnswer(
+ df,
+ joinedDf.collect().map {
+ case Row(id: Int, name: String, age: Int, idToDrop: Int, salary: Double) =>
+ Row(id, name, age, salary)
+ }.toSeq)
+ assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary"))
+ assert(df("id") == person("id"))
+ }
+
test("withColumnRenamed") {
val df = testData.toDF().withColumn("newCol", col("key") + 1)
.withColumnRenamed("value", "valueRenamed")
@@ -347,7 +402,7 @@ class DataFrameSuite extends QueryTest {
test("randomSplit") {
val n = 600
- val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
+ val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")
@@ -422,12 +477,63 @@ class DataFrameSuite extends QueryTest {
testData.select($"*").show(1000)
}
+ test("showString(negative)") {
+ val expectedAnswer = """+---+-----+
+ ||key|value|
+ |+---+-----+
+ |+---+-----+
+ |only showing top 0 rows
+ |""".stripMargin
+ assert(testData.select($"*").showString(-1) === expectedAnswer)
+ }
+
+ test("showString(0)") {
+ val expectedAnswer = """+---+-----+
+ ||key|value|
+ |+---+-----+
+ |+---+-----+
+ |only showing top 0 rows
+ |""".stripMargin
+ assert(testData.select($"*").showString(0) === expectedAnswer)
+ }
+
+ test("showString: array") {
+ val df = Seq(
+ (Array(1, 2, 3), Array(1, 2, 3)),
+ (Array(2, 3, 4), Array(2, 3, 4))
+ ).toDF()
+ val expectedAnswer = """+---------+---------+
+ || _1| _2|
+ |+---------+---------+
+ ||[1, 2, 3]|[1, 2, 3]|
+ ||[2, 3, 4]|[2, 3, 4]|
+ |+---------+---------+
+ |""".stripMargin
+ assert(df.showString(10) === expectedAnswer)
+ }
+
+ test("showString: minimum column width") {
+ val df = Seq(
+ (1, 1),
+ (2, 2)
+ ).toDF()
+ val expectedAnswer = """+---+---+
+ || _1| _2|
+ |+---+---+
+ || 1| 1|
+ || 2| 2|
+ |+---+---+
+ |""".stripMargin
+ assert(df.showString(10) === expectedAnswer)
+ }
+
test("SPARK-7319 showString") {
val expectedAnswer = """+---+-----+
||key|value|
|+---+-----+
|| 1| 1|
|+---+-----+
+ |only showing top 1 row
|""".stripMargin
assert(testData.select($"*").showString(1) === expectedAnswer)
}
@@ -442,19 +548,22 @@ class DataFrameSuite extends QueryTest {
}
test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
- val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
+ val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
- val df = TestSQLContext.createDataFrame(rowRDD, schema)
+ val df = ctx.createDataFrame(rowRDD, schema)
df.rdd.collect()
}
test("SPARK-6899") {
- val originalValue = TestSQLContext.conf.codegenEnabled
- TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
- checkAnswer(
- decimalData.agg(avg('a)),
- Row(new java.math.BigDecimal(2.0)))
- TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ val originalValue = ctx.conf.codegenEnabled
+ ctx.setConf(SQLConf.CODEGEN_ENABLED, "true")
+ try {
+ checkAnswer(
+ decimalData.agg(avg('a)),
+ Row(new java.math.BigDecimal(2.0)))
+ } finally {
+ ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ }
}
test("SPARK-7133: Implement struct, array, and map field accessor") {
@@ -465,14 +574,14 @@ class DataFrameSuite extends QueryTest {
}
test("SPARK-7551: support backticks for DataFrame attribute resolution") {
- val df = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD(
+ val df = ctx.read.json(ctx.sparkContext.makeRDD(
"""{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil))
checkAnswer(
df.select(df("`a.b`.c.`d..e`.`f`")),
Row(1)
)
- val df2 = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD(
+ val df2 = ctx.read.json(ctx.sparkContext.makeRDD(
"""{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil))
checkAnswer(
df2.select(df2("`a b`.c.d e.f")),
@@ -492,7 +601,7 @@ class DataFrameSuite extends QueryTest {
}
test("SPARK-7324 dropDuplicates") {
- val testData = TestSQLContext.sparkContext.parallelize(
+ val testData = ctx.sparkContext.parallelize(
(2, 1, 2) :: (1, 1, 1) ::
(1, 2, 1) :: (2, 1, 2) ::
(2, 2, 2) :: (2, 2, 1) ::
@@ -540,41 +649,49 @@ class DataFrameSuite extends QueryTest {
test("SPARK-7150 range api") {
// numSlice is greater than length
- val res1 = TestSQLContext.range(0, 10, 1, 15).select("id")
+ val res1 = ctx.range(0, 10, 1, 15).select("id")
assert(res1.count == 10)
assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
- val res2 = TestSQLContext.range(3, 15, 3, 2).select("id")
+ val res2 = ctx.range(3, 15, 3, 2).select("id")
assert(res2.count == 4)
assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
- val res3 = TestSQLContext.range(1, -2).select("id")
+ val res3 = ctx.range(1, -2).select("id")
assert(res3.count == 0)
// start is positive, end is negative, step is negative
- val res4 = TestSQLContext.range(1, -2, -2, 6).select("id")
+ val res4 = ctx.range(1, -2, -2, 6).select("id")
assert(res4.count == 2)
assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
// start, end, step are negative
- val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id")
+ val res5 = ctx.range(-3, -8, -2, 1).select("id")
assert(res5.count == 3)
assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
// start, end are negative, step is positive
- val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id")
+ val res6 = ctx.range(-8, -4, 2, 1).select("id")
assert(res6.count == 2)
assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
- val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id")
+ val res7 = ctx.range(-10, -9, -20, 1).select("id")
assert(res7.count == 0)
- val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+ val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
assert(res8.count == 3)
assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
- val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+ val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
assert(res9.count == 2)
assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+
+ // only end provided as argument
+ val res10 = ctx.range(10).select("id")
+ assert(res10.count == 10)
+ assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+ val res11 = ctx.range(-1).select("id")
+ assert(res11.count == 0)
}
}
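
The SPARK-6899 hunk above wraps the codegen assertion in try/finally so the flag is restored even if checkAnswer throws. The same idea can be factored into a tiny helper; the name withConf below is illustrative, not part of the patch, and it assumes the key already holds a value, as the flags toggled in these suites do:

import org.apache.spark.sql.SQLContext

object ConfRestoreSketch {
  // Run `body` with `key` temporarily set to `value`, restoring the previous value afterwards.
  def withConf[T](ctx: SQLContext, key: String, value: String)(body: => T): T = {
    val original = ctx.getConf(key, "")
    ctx.setConf(key, value)
    try {
      body
    } finally {
      ctx.setConf(key, original)
    }
  }
}

A test would then call, for example, withConf(ctx, SQLConf.CODEGEN_ENABLED, "true") { checkAnswer(...) } and the flag would be reset even on failure.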
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 407c78965783..ffd26c4f5a7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -20,27 +20,28 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
class JoinSuite extends QueryTest with BeforeAndAfterEach {
// Ensures tables are loaded.
TestData
+ lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+ import ctx.logicalPlanToSparkQuery
+
test("equi-join is hash-join") {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
- val planned = planner.HashJoin(join)
+ val planned = ctx.planner.HashJoin(join)
assert(planned.size === 1)
}
def assertJoin(sqlString: String, c: Class[_]): Any = {
- val df = sql(sqlString)
+ val df = ctx.sql(sqlString)
val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
case j: ShuffledHashJoin => j
@@ -61,9 +62,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("join operator selection") {
- cacheManager.clearCache()
+ ctx.cacheManager.clearCache()
- val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled
+ val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
@@ -94,22 +95,22 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
classOf[BroadcastNestedLoopJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
try {
- conf.setConf("spark.sql.planner.sortMergeJoin", "true")
+ ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true")
Seq(
("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
} finally {
- conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
+ ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
}
}
test("broadcasted hash join operator selection") {
- cacheManager.clearCache()
- sql("CACHE TABLE testData")
+ ctx.cacheManager.clearCache()
+ ctx.sql("CACHE TABLE testData")
- val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled
+ val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
Seq(
("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]),
@@ -117,7 +118,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
classOf[BroadcastHashJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
try {
- conf.setConf("spark.sql.planner.sortMergeJoin", "true")
+ ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true")
Seq(
("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
("SELECT * FROM testData join testData2 ON key = a and key = 2",
@@ -126,17 +127,17 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
classOf[BroadcastHashJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
} finally {
- conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
+ ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
}
- sql("UNCACHE TABLE testData")
+ ctx.sql("UNCACHE TABLE testData")
}
test("multiple-key equi-join is hash-join") {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
- val planned = planner.HashJoin(join)
+ val planned = ctx.planner.HashJoin(join)
assert(planned.size === 1)
}
@@ -241,7 +242,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
// Make sure we are choosing left.outputPartitioning as the
// outputPartitioning for the outer join operator.
checkAnswer(
- sql(
+ ctx.sql(
"""
|SELECT l.N, count(*)
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
@@ -255,7 +256,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(6, 1) :: Nil)
checkAnswer(
- sql(
+ ctx.sql(
"""
|SELECT r.a, count(*)
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
@@ -301,7 +302,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
// Make sure we are choosing right.outputPartitioning as the
// outputPartitioning for the outer join operator.
checkAnswer(
- sql(
+ ctx.sql(
"""
|SELECT l.a, count(*)
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -310,7 +311,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, 6))
checkAnswer(
- sql(
+ ctx.sql(
"""
|SELECT r.N, count(*)
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -362,7 +363,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
// Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator.
checkAnswer(
- sql(
+ ctx.sql(
"""
|SELECT l.a, count(*)
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -371,7 +372,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, 10))
checkAnswer(
- sql(
+ ctx.sql(
"""
|SELECT r.N, count(*)
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -386,7 +387,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, 4) :: Nil)
checkAnswer(
- sql(
+ ctx.sql(
"""
|SELECT l.N, count(*)
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
@@ -401,7 +402,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, 4) :: Nil)
checkAnswer(
- sql(
+ ctx.sql(
"""
|SELECT r.a, count(*)
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
@@ -411,11 +412,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("broadcasted left semi join operator selection") {
- cacheManager.clearCache()
- sql("CACHE TABLE testData")
- val tmp = conf.autoBroadcastJoinThreshold
+ ctx.cacheManager.clearCache()
+ ctx.sql("CACHE TABLE testData")
+ val tmp = ctx.conf.autoBroadcastJoinThreshold
- sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000")
+ ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000")
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
classOf[BroadcastLeftSemiJoinHash])
@@ -423,7 +424,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case (query, joinClass) => assertJoin(query, joinClass)
}
- sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
+ ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
@@ -431,12 +432,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case (query, joinClass) => assertJoin(query, joinClass)
}
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString)
- sql("UNCACHE TABLE testData")
+ ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString)
+ ctx.sql("UNCACHE TABLE testData")
}
test("left semi join") {
- val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
+ val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
checkAnswer(df,
Row(1, 1) ::
Row(1, 2) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 3ce97c3fffdb..2089660c52bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -19,49 +19,47 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfter
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
class ListTablesSuite extends QueryTest with BeforeAndAfter {
- import org.apache.spark.sql.test.TestSQLContext.implicits._
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
- val df =
- sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")
+ private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")
before {
df.registerTempTable("ListTablesSuiteTable")
}
after {
- catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
}
test("get all tables") {
checkAnswer(
- tables().filter("tableName = 'ListTablesSuiteTable'"),
+ ctx.tables().filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
checkAnswer(
- sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
+ ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
- catalog.unregisterTable(Seq("ListTablesSuiteTable"))
- assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
+ ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}
test("getting all Tables with a database name has no impact on returned table names") {
checkAnswer(
- tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
+ ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
checkAnswer(
- sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
+ ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
- catalog.unregisterTable(Seq("ListTablesSuiteTable"))
- assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
+ ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}
test("query the returned DataFrame of tables") {
@@ -69,19 +67,20 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
StructField("tableName", StringType, false) ::
StructField("isTemporary", BooleanType, false) :: Nil)
- Seq(tables(), sql("SHOW TABLes")).foreach {
+ Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach {
case tableDF =>
assert(expectedSchema === tableDF.schema)
tableDF.registerTempTable("tables")
checkAnswer(
- sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
+ ctx.sql(
+ "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
Row(true, "ListTablesSuiteTable")
)
checkAnswer(
- tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
+ ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
Row("tables", true))
- dropTempTable("tables")
+ ctx.dropTempTable("tables")
}
}
}
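
Both tables() and SHOW TABLES return a DataFrame whose schema the suite above pins down as (tableName: String, isTemporary: Boolean). A quick usage sketch against the same TestSQLContext, with a hypothetical table name:

object ListTablesSketch {
  private val ctx = org.apache.spark.sql.test.TestSQLContext
  import ctx.implicits._

  def main(args: Array[String]): Unit = {
    Seq((1, "a")).toDF("key", "value").registerTempTable("sketchTable")

    // Both the Scala API and the SQL command return a DataFrame of
    // (tableName: String, isTemporary: Boolean) rows.
    ctx.tables().filter("tableName = 'sketchTable'").show()
    ctx.sql("SHOW TABLES").filter("tableName = 'sketchTable'").show()

    ctx.dropTempTable("sketchTable")
  }
}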
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index c4281c4b55c0..e2daaf6b730c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -17,36 +17,30 @@
package org.apache.spark.sql
-import java.lang.{Double => JavaDouble}
-
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-
-private[this] object MathExpressionsTestData {
-
- case class DoubleData(a: JavaDouble, b: JavaDouble)
- val doubleData = TestSQLContext.sparkContext.parallelize(
- (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF()
-
- val nnDoubleData = TestSQLContext.sparkContext.parallelize(
- (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF()
-
- case class NullDoubles(a: JavaDouble)
- val nullDoubles =
- TestSQLContext.sparkContext.parallelize(
- NullDoubles(1.0) ::
- NullDoubles(2.0) ::
- NullDoubles(3.0) ::
- NullDoubles(null) :: Nil
- ).toDF()
+import org.apache.spark.sql.functions.{log => logarithm}
+
+
+private object MathExpressionsTestData {
+ case class DoubleData(a: java.lang.Double, b: java.lang.Double)
+ case class NullDoubles(a: java.lang.Double)
}
class MathExpressionsSuite extends QueryTest {
import MathExpressionsTestData._
- def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T](
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
+ private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF()
+
+ private lazy val nnDoubleData = (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1)).toDF()
+
+ private lazy val nullDoubles =
+ Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF()
+
+ private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T](
c: Column => Column,
f: T => T): Unit = {
checkAnswer(
@@ -65,7 +59,8 @@ class MathExpressionsSuite extends QueryTest {
)
}
- def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = {
+ private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit =
+ {
checkAnswer(
nnDoubleData.select(c('a)),
(1 to 10).map(n => Row(f(n * 0.1)))
@@ -89,7 +84,7 @@ class MathExpressionsSuite extends QueryTest {
)
}
- def testTwoToOneMathFunction(
+ private def testTwoToOneMathFunction(
c: (Column, Column) => Column,
d: (Column, Double) => Column,
f: (Double, Double) => Double): Unit = {
@@ -157,20 +152,31 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction(tanh, math.tanh)
}
- test("toDeg") {
+ test("toDegrees") {
testOneToOneMathFunction(toDegrees, math.toDegrees)
+ checkAnswer(
+ ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
+ Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5)))
+ )
}
- test("toRad") {
+ test("toRadians") {
testOneToOneMathFunction(toRadians, math.toRadians)
+ checkAnswer(
+ ctx.sql("SELECT radians(0), radians(1), radians(1.5)"),
+ Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5)))
+ )
}
test("cbrt") {
testOneToOneMathFunction(cbrt, math.cbrt)
}
- test("ceil") {
+ test("ceil and ceiling") {
testOneToOneMathFunction(ceil, math.ceil)
+ checkAnswer(
+ ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
+ Row(0.0, 1.0, 2.0))
}
test("floor") {
@@ -189,12 +195,21 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction(expm1, math.expm1)
}
- test("signum") {
+ test("signum / sign") {
testOneToOneMathFunction[Double](signum, math.signum)
+
+ checkAnswer(
+ ctx.sql("SELECT sign(10), signum(-11)"),
+ Row(1, -1))
}
- test("pow") {
+ test("pow / power") {
testTwoToOneMathFunction(pow, pow, math.pow)
+
+ checkAnswer(
+ ctx.sql("SELECT pow(1, 2), power(2, 1)"),
+ Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1)))
+ )
}
test("hypot") {
@@ -205,8 +220,12 @@ class MathExpressionsSuite extends QueryTest {
testTwoToOneMathFunction(atan2, atan2, math.atan2)
}
- test("log") {
- testOneToOneNonNegativeMathFunction(log, math.log)
+ test("log / ln") {
+ testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log)
+ checkAnswer(
+ ctx.sql("SELECT ln(0), ln(1), ln(1.5)"),
+ Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5)))
+ )
}
test("log10") {
@@ -217,4 +236,37 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneNonNegativeMathFunction(log1p, math.log1p)
}
+ test("abs") {
+ val input =
+ Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5))
+ checkAnswer(
+ input.toDF("key", "value").select(abs($"key").alias("a")).sort("a"),
+ input.map(pair => Row(pair._2)))
+
+ checkAnswer(
+ input.toDF("key", "value").selectExpr("abs(key) a").sort("a"),
+ input.map(pair => Row(pair._2)))
+ }
+
+ test("log2") {
+ val df = Seq((1, 2)).toDF("a", "b")
+ checkAnswer(
+ df.select(log2("b") + log2("a")),
+ Row(1))
+
+ checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
+ }
+
+ test("negative") {
+ checkAnswer(
+ ctx.sql("SELECT negative(1), negative(0), negative(-1)"),
+ Row(-1, 0, 1))
+ }
+
+ test("positive") {
+ val df = Seq((1, -1, "abc")).toDF("a", "b", "c")
+ checkAnswer(df.selectExpr("positive(a)"), Row(1))
+ checkAnswer(df.selectExpr("positive(b)"), Row(-1))
+ checkAnswer(df.selectExpr("positive(c)"), Row("abc"))
+ }
}
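
Several of the renamed tests above (ceil/ceiling, pow/power, log/ln, signum/sign) now exercise each function twice: once through the Column-based API and once through its SQL alias. A minimal sketch of that dual check, using ceiling as the example:

import org.apache.spark.sql.functions.{ceil, lit}

object SqlAliasSketch {
  private val ctx = org.apache.spark.sql.test.TestSQLContext

  def main(args: Array[String]): Unit = {
    // The DataFrame API spelling...
    ctx.range(1).select(ceil(lit(1.5))).show()

    // ...and the SQL alias exercised through the parser, as the suite does with ceiling().
    ctx.sql("SELECT ceiling(1.5)").show()
  }
}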
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index fb3ba4bc1b90..d84b57af9c88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -17,15 +17,16 @@
package org.apache.spark.sql
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
-import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
-class RowSuite extends FunSuite {
+class RowSuite extends SparkFunSuite {
+
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
test("create row") {
val expected = new GenericMutableRow(4)
@@ -56,7 +57,7 @@ class RowSuite extends FunSuite {
test("serialize w/ kryo") {
val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first()
- val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf)
+ val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf)
val instance = serializer.newInstance()
val ser = instance.serialize(row)
val de = instance.deserialize(ser).asInstanceOf[Row]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index bf73d0c7074a..76d0dd1744a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -17,68 +17,64 @@
package org.apache.spark.sql
-import org.scalatest.FunSuiteLike
-import org.apache.spark.sql.test._
+class SQLConfSuite extends QueryTest {
-/* Implicits */
-import TestSQLContext._
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-class SQLConfSuite extends QueryTest with FunSuiteLike {
-
- val testKey = "test.key.0"
- val testVal = "test.val.0"
+ private val testKey = "test.key.0"
+ private val testVal = "test.val.0"
test("propagate from spark conf") {
// We create a new context here to avoid order dependence with other tests that might call
// clear().
- val newContext = new SQLContext(TestSQLContext.sparkContext)
- assert(newContext.getConf("spark.sql.testkey", "false") == "true")
+ val newContext = new SQLContext(ctx.sparkContext)
+ assert(newContext.getConf("spark.sql.testkey", "false") === "true")
}
test("programmatic ways of basic setting and getting") {
- conf.clear()
- assert(getAllConfs.size === 0)
+ ctx.conf.clear()
+ assert(ctx.getAllConfs.size === 0)
- setConf(testKey, testVal)
- assert(getConf(testKey) == testVal)
- assert(getConf(testKey, testVal + "_") == testVal)
- assert(getAllConfs.contains(testKey))
+ ctx.setConf(testKey, testVal)
+ assert(ctx.getConf(testKey) === testVal)
+ assert(ctx.getConf(testKey, testVal + "_") === testVal)
+ assert(ctx.getAllConfs.contains(testKey))
// Tests SQLConf as accessed from a SQLContext is mutable after
// the latter is initialized, unlike SparkConf inside a SparkContext.
- assert(TestSQLContext.getConf(testKey) == testVal)
- assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal)
- assert(TestSQLContext.getAllConfs.contains(testKey))
+ assert(ctx.getConf(testKey) === testVal)
+ assert(ctx.getConf(testKey, testVal + "_") === testVal)
+ assert(ctx.getAllConfs.contains(testKey))
- conf.clear()
+ ctx.conf.clear()
}
test("parse SQL set commands") {
- conf.clear()
- sql(s"set $testKey=$testVal")
- assert(getConf(testKey, testVal + "_") == testVal)
- assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal)
+ ctx.conf.clear()
+ ctx.sql(s"set $testKey=$testVal")
+ assert(ctx.getConf(testKey, testVal + "_") === testVal)
+ assert(ctx.getConf(testKey, testVal + "_") === testVal)
- sql("set some.property=20")
- assert(getConf("some.property", "0") == "20")
- sql("set some.property = 40")
- assert(getConf("some.property", "0") == "40")
+ ctx.sql("set some.property=20")
+ assert(ctx.getConf("some.property", "0") === "20")
+ ctx.sql("set some.property = 40")
+ assert(ctx.getConf("some.property", "0") === "40")
val key = "spark.sql.key"
val vs = "val0,val_1,val2.3,my_table"
- sql(s"set $key=$vs")
- assert(getConf(key, "0") == vs)
+ ctx.sql(s"set $key=$vs")
+ assert(ctx.getConf(key, "0") === vs)
- sql(s"set $key=")
- assert(getConf(key, "0") == "")
+ ctx.sql(s"set $key=")
+ assert(ctx.getConf(key, "0") === "")
- conf.clear()
+ ctx.conf.clear()
}
test("deprecated property") {
- conf.clear()
- sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
- assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10")
+ ctx.conf.clear()
+ ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
+ assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS) === "10")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index f186bc1c1812..c8d8796568a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -17,33 +17,32 @@
package org.apache.spark.sql
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.SparkFunSuite
-class SQLContextSuite extends FunSuite with BeforeAndAfterAll {
+class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll {
- private val testSqlContext = TestSQLContext
- private val testSparkContext = TestSQLContext.sparkContext
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
override def afterAll(): Unit = {
- SQLContext.setLastInstantiatedContext(testSqlContext)
+ SQLContext.setLastInstantiatedContext(ctx)
}
test("getOrCreate instantiates SQLContext") {
SQLContext.clearLastInstantiatedContext()
- val sqlContext = SQLContext.getOrCreate(testSparkContext)
+ val sqlContext = SQLContext.getOrCreate(ctx.sparkContext)
assert(sqlContext != null, "SQLContext.getOrCreate returned null")
- assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext),
+ assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext),
"SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate")
}
test("getOrCreate gets last explicitly instantiated SQLContext") {
SQLContext.clearLastInstantiatedContext()
- val sqlContext = new SQLContext(testSparkContext)
- assert(SQLContext.getOrCreate(testSparkContext) != null,
+ val sqlContext = new SQLContext(ctx.sparkContext)
+ assert(SQLContext.getOrCreate(ctx.sparkContext) != null,
"SQLContext.getOrCreate after explicitly created SQLContext returned null")
- assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext),
+ assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext),
"SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
}
}
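
The two getOrCreate tests above describe the singleton behaviour: repeated calls return one instance, and an explicitly constructed SQLContext becomes the instance handed back afterwards. A standalone sketch of the same flow (the local master and app name are illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object GetOrCreateSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("getOrCreate-sketch"))

    // Repeated calls hand back the same instance.
    val first = SQLContext.getOrCreate(sc)
    val second = SQLContext.getOrCreate(sc)
    assert(first eq second)

    // An explicitly constructed context becomes the instance getOrCreate returns afterwards.
    val explicit = new SQLContext(sc)
    assert(SQLContext.getOrCreate(sc) eq explicit)

    sc.stop()
  }
}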
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bf18bf854aa4..30db840166ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -24,20 +24,36 @@ import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
-
+import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
/** A SQL Dialect for testing purpose, and it can not be nested type */
class MyDialect extends DefaultParserDialect
-class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
+class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Make sure the tables are loaded.
TestData
- import org.apache.spark.sql.test.TestSQLContext.implicits._
- val sqlCtx = TestSQLContext
+ val sqlContext = org.apache.spark.sql.test.TestSQLContext
+ import sqlContext.implicits._
+ import sqlContext.sql
+
+ test("having clause") {
+ Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav")
+ checkAnswer(
+ sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"),
+ Row("one", 6) :: Row("three", 3) :: Nil)
+ }
+
+ test("SPARK-8010: promote numeric to string") {
+ val df = Seq((1, 1)).toDF("key", "value")
+ df.registerTempTable("src")
+ val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ")
+ val queryCoalesce = sql("select coalesce(null, 1, '1') from src ")
+
+ checkAnswer(queryCaseWhen, Row("1.0") :: Nil)
+ checkAnswer(queryCoalesce, Row("1") :: Nil)
+ }
test("SPARK-6743: no columns from cache") {
Seq(
@@ -46,7 +62,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
(43, 81, 24)
).toDF("a", "b", "c").registerTempTable("cachedData")
- cacheTable("cachedData")
+ sqlContext.cacheTable("cachedData")
checkAnswer(
sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
Row(0) :: Row(81) :: Nil)
@@ -94,14 +110,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("SQL Dialect Switching to a new SQL parser") {
- val newContext = new SQLContext(TestSQLContext.sparkContext)
+ val newContext = new SQLContext(sqlContext.sparkContext)
newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName())
assert(newContext.getSQLDialect().getClass === classOf[MyDialect])
assert(newContext.sql("SELECT 1").collect() === Array(Row(1)))
}
test("SQL Dialect Switch to an invalid parser with alias") {
- val newContext = new SQLContext(TestSQLContext.sparkContext)
+ val newContext = new SQLContext(sqlContext.sparkContext)
newContext.sql("SET spark.sql.dialect=MyTestClass")
intercept[DialectException] {
newContext.sql("SELECT 1")
@@ -117,8 +133,35 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
)
}
+ test("SPARK-7158 collect and take return different results") {
+ import java.util.UUID
+ import org.apache.spark.sql.types._
+
+ val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
+ // we expect the id to be materialized only once
+ def id: () => String = () => { UUID.randomUUID().toString() }
+
+ val dfWithId = df.withColumn("id", callUDF(id, StringType))
+ // Make a new DataFrame (actually the same reference to the old one)
+ val cached = dfWithId.cache()
+ // Trigger the cache
+ val d0 = dfWithId.collect()
+ val d1 = cached.collect()
+ val d2 = cached.collect()
+
+ // Since the ID is materialized only once, all of the records
+ // should come from the cache rather than being recomputed. Otherwise,
+ // the IDs would differ.
+ assert(d0.map(_(0)) === d2.map(_(0)))
+ assert(d0.map(_(1)) === d2.map(_(1)))
+
+ assert(d1.map(_(0)) === d2.map(_(0)))
+ assert(d1.map(_(1)) === d2.map(_(1)))
+ }
+
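
The SPARK-7158 test above relies on the cached DataFrame evaluating the non-deterministic UDF column only once, so repeated actions see identical ids. A sketch of the same idea using functions.udf rather than callUDF; names and data are illustrative:

import java.util.UUID

import org.apache.spark.sql.functions.udf

object CachedUdfSketch {
  private val ctx = org.apache.spark.sql.test.TestSQLContext
  import ctx.implicits._

  def main(args: Array[String]): Unit = {
    val makeId = udf(() => UUID.randomUUID().toString)

    val cached = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
      .withColumn("id", makeId())
      .cache()

    // The first action materializes the cache, running the UDF once per row;
    // later actions read the cached rows, so the generated ids do not change.
    val first = cached.collect()
    val second = cached.collect()
    assert(first.map(_.getString(1)).toSeq == second.map(_.getString(1)).toSeq)
  }
}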
test("grouping on nested fields") {
- read.json(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil))
+ sqlContext.read.json(sqlContext.sparkContext.parallelize(
+ """{"nested": {"attribute": 1}, "value": 2}""" :: Nil))
.registerTempTable("rows")
checkAnswer(
@@ -135,8 +178,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("SPARK-6201 IN type conversion") {
- read.json(
- sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}")))
+ sqlContext.read.json(
+ sqlContext.sparkContext.parallelize(
+ Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}")))
.registerTempTable("d")
checkAnswer(
@@ -144,25 +188,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Seq(Row("1"), Row("2")))
}
- test("SPARK-3176 Added Parser of SQL ABS()") {
- checkAnswer(
- sql("SELECT ABS(-1.3)"),
- Row(1.3))
- checkAnswer(
- sql("SELECT ABS(0.0)"),
- Row(0.0))
- checkAnswer(
- sql("SELECT ABS(2.5)"),
- Row(2.5))
- }
-
test("aggregation with codegen") {
- val originalValue = conf.codegenEnabled
- setConf(SQLConf.CODEGEN_ENABLED, "true")
+ val originalValue = sqlContext.conf.codegenEnabled
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
// Prepare a table that we can group some rows.
- table("testData")
- .unionAll(table("testData"))
- .unionAll(table("testData"))
+ sqlContext.table("testData")
+ .unionAll(sqlContext.table("testData"))
+ .unionAll(sqlContext.table("testData"))
.registerTempTable("testData3x")
def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
@@ -184,77 +216,79 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(df, expectedResults)
}
- // Just to group rows.
- testCodeGen(
- "SELECT key FROM testData3x GROUP BY key",
- (1 to 100).map(Row(_)))
- // COUNT
- testCodeGen(
- "SELECT key, count(value) FROM testData3x GROUP BY key",
- (1 to 100).map(i => Row(i, 3)))
- testCodeGen(
- "SELECT count(key) FROM testData3x",
- Row(300) :: Nil)
- // COUNT DISTINCT ON int
- testCodeGen(
- "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, 1)))
- testCodeGen(
- "SELECT count(distinct key) FROM testData3x",
- Row(100) :: Nil)
- // SUM
- testCodeGen(
- "SELECT value, sum(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, 3 * i)))
- testCodeGen(
- "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
- Row(5050 * 3, 5050 * 3.0) :: Nil)
- // AVERAGE
- testCodeGen(
- "SELECT value, avg(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, i)))
- testCodeGen(
- "SELECT avg(key) FROM testData3x",
- Row(50.5) :: Nil)
- // MAX
- testCodeGen(
- "SELECT value, max(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, i)))
- testCodeGen(
- "SELECT max(key) FROM testData3x",
- Row(100) :: Nil)
- // MIN
- testCodeGen(
- "SELECT value, min(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, i)))
- testCodeGen(
- "SELECT min(key) FROM testData3x",
- Row(1) :: Nil)
- // Some combinations.
- testCodeGen(
- """
- |SELECT
- | value,
- | sum(key),
- | max(key),
- | min(key),
- | avg(key),
- | count(key),
- | count(distinct key)
- |FROM testData3x
- |GROUP BY value
- """.stripMargin,
- (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1)))
- testCodeGen(
- "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
- Row(100, 1, 50.5, 300, 100) :: Nil)
- // Aggregate with Code generation handling all null values
- testCodeGen(
- "SELECT sum('a'), avg('a'), count(null) FROM testData",
- Row(0, null, 0) :: Nil)
-
- dropTempTable("testData3x")
- setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ try {
+ // Just to group rows.
+ testCodeGen(
+ "SELECT key FROM testData3x GROUP BY key",
+ (1 to 100).map(Row(_)))
+ // COUNT
+ testCodeGen(
+ "SELECT key, count(value) FROM testData3x GROUP BY key",
+ (1 to 100).map(i => Row(i, 3)))
+ testCodeGen(
+ "SELECT count(key) FROM testData3x",
+ Row(300) :: Nil)
+ // COUNT DISTINCT ON int
+ testCodeGen(
+ "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 1)))
+ testCodeGen(
+ "SELECT count(distinct key) FROM testData3x",
+ Row(100) :: Nil)
+ // SUM
+ testCodeGen(
+ "SELECT value, sum(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 3 * i)))
+ testCodeGen(
+ "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
+ Row(5050 * 3, 5050 * 3.0) :: Nil)
+ // AVERAGE
+ testCodeGen(
+ "SELECT value, avg(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT avg(key) FROM testData3x",
+ Row(50.5) :: Nil)
+ // MAX
+ testCodeGen(
+ "SELECT value, max(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT max(key) FROM testData3x",
+ Row(100) :: Nil)
+ // MIN
+ testCodeGen(
+ "SELECT value, min(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT min(key) FROM testData3x",
+ Row(1) :: Nil)
+ // Some combinations.
+ testCodeGen(
+ """
+ |SELECT
+ | value,
+ | sum(key),
+ | max(key),
+ | min(key),
+ | avg(key),
+ | count(key),
+ | count(distinct key)
+ |FROM testData3x
+ |GROUP BY value
+ """.stripMargin,
+ (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1)))
+ testCodeGen(
+ "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
+ Row(100, 1, 50.5, 300, 100) :: Nil)
+ // Aggregate with Code generation handling all null values
+ testCodeGen(
+ "SELECT sum('a'), avg('a'), count(null) FROM testData",
+ Row(0, null, 0) :: Nil)
+ } finally {
+ sqlContext.dropTempTable("testData3x")
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ }
}
test("Add Parser of SQL COALESCE()") {
@@ -445,37 +479,43 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("sorting") {
- val before = conf.externalSortEnabled
- setConf(SQLConf.EXTERNAL_SORT, "false")
+ val before = sqlContext.conf.externalSortEnabled
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false")
sortTest()
- setConf(SQLConf.EXTERNAL_SORT, before.toString)
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString)
}
test("external sorting") {
- val before = conf.externalSortEnabled
- setConf(SQLConf.EXTERNAL_SORT, "true")
+ val before = sqlContext.conf.externalSortEnabled
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true")
sortTest()
- setConf(SQLConf.EXTERNAL_SORT, before.toString)
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString)
}
test("SPARK-6927 sorting with codegen on") {
- val externalbefore = conf.externalSortEnabled
- val codegenbefore = conf.codegenEnabled
- setConf(SQLConf.EXTERNAL_SORT, "false")
- setConf(SQLConf.CODEGEN_ENABLED, "true")
- sortTest()
- setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
- setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ val externalbefore = sqlContext.conf.externalSortEnabled
+ val codegenbefore = sqlContext.conf.codegenEnabled
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false")
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
+ try {
+ sortTest()
+ } finally {
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ }
}
test("SPARK-6927 external sorting with codegen on") {
- val externalbefore = conf.externalSortEnabled
- val codegenbefore = conf.codegenEnabled
- setConf(SQLConf.CODEGEN_ENABLED, "true")
- setConf(SQLConf.EXTERNAL_SORT, "true")
- sortTest()
- setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
- setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ val externalbefore = sqlContext.conf.externalSortEnabled
+ val codegenbefore = sqlContext.conf.codegenEnabled
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true")
+ try {
+ sortTest()
+ } finally {
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ }
}
test("limit") {
@@ -508,7 +548,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Allow only a single WITH clause per query") {
intercept[RuntimeException] {
- sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2")
+ sql(
+ "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2")
}
}
@@ -654,7 +695,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
- ignore("cartesian product join") {
+ test("cartesian product join") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
@@ -855,7 +896,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("SET commands semantics using sql()") {
- conf.clear()
+ sqlContext.conf.clear()
val testKey = "test.key.0"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
@@ -887,17 +928,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql(s"SET $nonexistentKey"),
Row(s"$nonexistentKey=")
)
- conf.clear()
+ sqlContext.conf.clear()
}
test("SET commands with illegal or inappropriate argument") {
- conf.clear()
+ sqlContext.conf.clear()
// Setting mapred.reduce.tasks to a negative value for automatically
// determining the number of reducers is not supported.
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1"))
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01"))
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2"))
- conf.clear()
+ sqlContext.conf.clear()
}
test("apply schema") {
@@ -915,7 +956,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}
- val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
+ val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
@@ -945,7 +986,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
+ val df2 = sqlContext.createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
@@ -970,7 +1011,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}
- val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
+ val df3 = sqlContext.createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")
checkAnswer(
@@ -1015,7 +1056,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
+ val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
@@ -1030,7 +1071,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("SPARK-3371 Renaming a function expression with group by gives error") {
- TestSQLContext.udf.register("len", (s: String) => s.length)
+ sqlContext.udf.register("len", (s: String) => s.length)
checkAnswer(
sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"),
Row(1))
@@ -1211,9 +1252,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("SPARK-3483 Special chars in column names") {
- val data = sparkContext.parallelize(
+ val data = sqlContext.sparkContext.parallelize(
Seq("""{"key?number1": "value1", "key.number2": "value2"}"""))
- read.json(data).registerTempTable("records")
+ sqlContext.read.json(data).registerTempTable("records")
sql("SELECT `key?number1`, `key.number2` FROM records")
}
@@ -1254,13 +1295,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("SPARK-4322 Grouping field with struct field as sub expression") {
- read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data")
+ sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil))
+ .registerTempTable("data")
checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1))
- dropTempTable("data")
+ sqlContext.dropTempTable("data")
- read.json(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
+ sqlContext.read.json(
+ sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2))
- dropTempTable("data")
+ sqlContext.dropTempTable("data")
}
test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") {
@@ -1279,10 +1322,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Supporting relational operator '<=>' in Spark SQL") {
val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil
- val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i)))
+ val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i)))
rdd1.toDF().registerTempTable("nulldata1")
val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil
- val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i)))
+ val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i)))
rdd2.toDF().registerTempTable("nulldata2")
checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " +
"nulldata2 on nulldata1.value <=> nulldata2.value"),
@@ -1291,22 +1334,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Multi-column COUNT(DISTINCT ...)") {
val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
- val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
+ val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i)))
rdd.toDF().registerTempTable("distinctData")
checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2))
}
test("SPARK-4699 case sensitivity SQL query") {
- setConf(SQLConf.CASE_SENSITIVE, "false")
+ sqlContext.setConf(SQLConf.CASE_SENSITIVE, "false")
val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
- val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
+ val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i)))
rdd.toDF().registerTempTable("testTable1")
checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
- setConf(SQLConf.CASE_SENSITIVE, "true")
+ sqlContext.setConf(SQLConf.CASE_SENSITIVE, "true")
}
test("SPARK-6145: ORDER BY test for nested fields") {
- read.json(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil))
+ sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil))
.registerTempTable("nestedOrder")
checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1))
@@ -1318,17 +1362,90 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("SPARK-6145: special cases") {
- read.json(sparkContext.makeRDD(
+ sqlContext.read.json(sqlContext.sparkContext.makeRDD(
"""{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
}
test("SPARK-6898: complete support for special chars in column names") {
- read.json(sparkContext.makeRDD(
+ sqlContext.read.json(sqlContext.sparkContext.makeRDD(
"""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil))
.registerTempTable("t")
checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
}
+
+ test("SPARK-6583 order by aggregated function") {
+ Seq("1" -> 3, "1" -> 4, "2" -> 7, "2" -> 8, "3" -> 5, "3" -> 6, "4" -> 1, "4" -> 2)
+ .toDF("a", "b").registerTempTable("orderByData")
+
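+ // The rows should come back ordered by the aggregated sum(b), not by the grouping column a.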
+ checkAnswer(
+ sql(
+ """
+ |SELECT a
+ |FROM orderByData
+ |GROUP BY a
+ |ORDER BY sum(b)
+ """.stripMargin),
+ Row("4") :: Row("1") :: Row("3") :: Row("2") :: Nil)
+
+ checkAnswer(
+ sql(
+ """
+ |SELECT sum(b)
+ |FROM orderByData
+ |GROUP BY a
+ |ORDER BY sum(b)
+ """.stripMargin),
+ Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil)
+
+ checkAnswer(
+ sql(
+ """
+ |SELECT a, sum(b)
+ |FROM orderByData
+ |GROUP BY a
+ |ORDER BY sum(b)
+ """.stripMargin),
+ Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil)
+
+ checkAnswer(
+ sql(
+ """
+ |SELECT a, sum(b)
+ |FROM orderByData
+ |GROUP BY a
+ |ORDER BY sum(b) + 1
+ """.stripMargin),
+ Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil)
+ }
+
+ test("SPARK-7952: fix the equality check between boolean and numeric types") {
+ withTempTable("t") {
+ // numeric field i, boolean field b, result of i = b, result of i <=> b
+ Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)](
+ (1, true, true, true),
+ (0, false, true, true),
+ (2, true, false, false),
+ (2, false, false, false),
+ (null, true, null, false),
+ (null, false, null, false),
+ (0, null, null, false),
+ (1, null, null, false),
+ (null, null, null, true)
+ ).toDF("i", "b", "r1", "r2").registerTempTable("t")
+
+ checkAnswer(sql("select i = b from t"), sql("select r1 from t"))
+ checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
+ }
+ }
+
+ test("SPARK-7067: order by queries for complex ExtractValue chain") {
+ withTempTable("t") {
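+ // Ordering by b[0].d resolves an ExtractValue chain: an array element access followed by a struct field access.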
+ sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t")
+ checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1))))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index 52d265b445e1..ece3d6fdf2af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -19,10 +19,8 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.test.TestSQLContext._
case class ReflectData(
stringField: String,
@@ -74,17 +72,17 @@ case class ComplexReflectData(
mapFieldContainsNull: Map[Int, Option[Long]],
dataField: Data)
-class ScalaReflectionRelationSuite extends FunSuite {
+class ScalaReflectionRelationSuite extends SparkFunSuite {
- import org.apache.spark.sql.test.TestSQLContext.implicits._
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3))
- val rdd = sparkContext.parallelize(data :: Nil)
- rdd.toDF().registerTempTable("reflectData")
+ Seq(data).toDF().registerTempTable("reflectData")
- assert(sql("SELECT * FROM reflectData").collect().head ===
+ assert(ctx.sql("SELECT * FROM reflectData").collect().head ===
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
new Timestamp(12345), Seq(1, 2, 3)))
@@ -92,27 +90,26 @@ class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD with nulls") {
val data = NullReflectData(null, null, null, null, null, null, null)
- val rdd = sparkContext.parallelize(data :: Nil)
- rdd.toDF().registerTempTable("reflectNullData")
+ Seq(data).toDF().registerTempTable("reflectNullData")
- assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
+ assert(ctx.sql("SELECT * FROM reflectNullData").collect().head ===
+ Row.fromSeq(Seq.fill(7)(null)))
}
test("query case class RDD with Nones") {
val data = OptionalReflectData(None, None, None, None, None, None, None)
- val rdd = sparkContext.parallelize(data :: Nil)
- rdd.toDF().registerTempTable("reflectOptionalData")
+ Seq(data).toDF().registerTempTable("reflectOptionalData")
- assert(sql("SELECT * FROM reflectOptionalData").collect().head ===
+ assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head ===
Row.fromSeq(Seq.fill(7)(null)))
}
// Equality is broken for Arrays, so we test that separately.
test("query binary data") {
- val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
- rdd.toDF().registerTempTable("reflectBinary")
+ Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary")
- val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]]
+ val result = ctx.sql("SELECT data FROM reflectBinary")
+ .collect().head(0).asInstanceOf[Array[Byte]]
assert(result.toSeq === Seq[Byte](1))
}
@@ -128,10 +125,9 @@ class ScalaReflectionRelationSuite extends FunSuite {
Map(10 -> 100L, 20 -> 200L),
Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None),
Nested(None, "abc")))
- val rdd = sparkContext.parallelize(data :: Nil)
- rdd.toDF().registerTempTable("reflectComplexData")
- assert(sql("SELECT * FROM reflectComplexData").collect().head ===
+ Seq(data).toDF().registerTempTable("reflectComplexData")
+ assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head ===
new GenericRow(Array[Any](
Seq(1, 2, 3),
Seq(1, 2, null),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
index 6f6d3c9c243d..e55c9e460b79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -17,16 +17,15 @@
package org.apache.spark.sql
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.sql.test.TestSQLContext
-class SerializationSuite extends FunSuite {
+class SerializationSuite extends SparkFunSuite {
+
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
test("[SPARK-5235] SQLContext should be serializable") {
- val sqlContext = new SQLContext(TestSQLContext.sparkContext)
+ val sqlContext = new SQLContext(ctx.sparkContext)
new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 1a9ba66416b2..703a34c47ec2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,43 +17,83 @@
package org.apache.spark.sql
-import org.apache.spark.sql.test._
-
-/* Implicits */
-import TestSQLContext._
-import TestSQLContext.implicits._
case class FunctionResult(f1: String, f2: String)
class UDFSuite extends QueryTest {
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
+ test("built-in fixed arity expressions") {
+ val df = ctx.emptyDataFrame
+ df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)")
+ }
+
+ test("built-in vararg expressions") {
+ val df = Seq((1, 2)).toDF("a", "b")
+ df.selectExpr("array(a, b)")
+ df.selectExpr("struct(a, b)")
+ }
+
+ test("built-in expressions with multiple constructors") {
+ val df = Seq(("abcd", 2)).toDF("a", "b")
+ df.selectExpr("substr(a, 2)", "substr(a, 2, 3)").collect()
+ }
+
+ test("count") {
+ val df = Seq(("abcd", 2)).toDF("a", "b")
+ df.selectExpr("count(a)")
+ }
+
+ test("count distinct") {
+ val df = Seq(("abcd", 2)).toDF("a", "b")
+ df.selectExpr("count(distinct a)")
+ }
+
+ test("error reporting for incorrect number of arguments") {
+ val df = ctx.emptyDataFrame
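+ // substr takes at most three arguments, so a four-argument call should fail analysis.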
+ val e = intercept[AnalysisException] {
+ df.selectExpr("substr('abcd', 2, 3, 4)")
+ }
+ assert(e.getMessage.contains("arguments"))
+ }
+
+ test("error reporting for undefined functions") {
+ val df = ctx.emptyDataFrame
+ val e = intercept[AnalysisException] {
+ df.selectExpr("a_function_that_does_not_exist()")
+ }
+ assert(e.getMessage.contains("undefined function"))
+ }
+
test("Simple UDF") {
- udf.register("strLenScala", (_: String).length)
- assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
+ ctx.udf.register("strLenScala", (_: String).length)
+ assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4)
}
test("ZeroArgument UDF") {
- udf.register("random0", () => { Math.random()})
- assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
+ ctx.udf.register("random0", () => { Math.random()})
+ assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0)
}
test("TwoArgument UDF") {
- udf.register("strLenScala", (_: String).length + (_: Int))
- assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
+ ctx.udf.register("strLenScala", (_: String).length + (_: Int))
+ assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}
test("struct UDF") {
- udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+ ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
val result =
- sql("SELECT returnStruct('test', 'test2') as ret")
+ ctx.sql("SELECT returnStruct('test', 'test2') as ret")
.select($"ret.f1").head().getString(0)
assert(result === "test")
}
test("udf that is transformed") {
- udf.register("makeStruct", (x: Int, y: Int) => (x, y))
+ ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
// 1 + 1 is constant folded causing a transformation.
- assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
+ assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index dc2d43a197f4..45c9f06941c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql
-import java.io.File
-
-import org.apache.spark.util.Utils
-
import scala.beans.{BeanInfo, BeanProperty}
import com.clearspring.analytics.stream.cardinality.HyperLogLog
@@ -28,12 +24,11 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashSet
+
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
override def equals(other: Any): Boolean = other match {
@@ -72,11 +67,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
}
class UserDefinedTypeSuite extends QueryTest {
- val points = Seq(
- MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
- MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
- val pointsRDD = sparkContext.parallelize(points).toDF()
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
+ private lazy val pointsRDD = Seq(
+ MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
+ MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))).toDF()
test("register user type: MyDenseVector for MyLabeledPoint") {
val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
@@ -94,10 +91,10 @@ class UserDefinedTypeSuite extends QueryTest {
}
test("UDTs and UDFs") {
- TestSQLContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
+ ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
pointsRDD.registerTempTable("points")
checkAnswer(
- sql("SELECT testType(features) from points"),
+ ctx.sql("SELECT testType(features) from points"),
Seq(Row(true), Row(true)))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 7cefcf44061c..1f37455dd0bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -17,27 +17,29 @@
package org.apache.spark.sql.columnar
-import org.scalatest.FunSuite
-
-import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.InternalRow
import org.apache.spark.sql.types._
-class ColumnStatsSuite extends FunSuite {
- testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0))
- testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0))
- testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0))
- testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0))
- testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0))
- testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0))
- testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), Row(null, null, 0))
- testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
- testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0))
- testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))
+class ColumnStatsSuite extends SparkFunSuite {
+ testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0))
+ testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0))
+ testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0))
+ testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0))
+ testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0))
+ testColumnStats(classOf[DoubleColumnStats], DOUBLE,
+ InternalRow(Double.MaxValue, Double.MinValue, 0))
+ testColumnStats(classOf[FixedDecimalColumnStats],
+ FIXED_DECIMAL(15, 10), InternalRow(null, null, 0))
+ testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0))
+ testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0))
+ testColumnStats(classOf[TimestampColumnStats], TIMESTAMP,
+ InternalRow(Long.MaxValue, Long.MinValue, 0))
def testColumnStats[T <: AtomicType, U <: ColumnStats](
columnStatsClass: Class[U],
columnType: NativeColumnType[T],
- initialStatistics: Row): Unit = {
+ initialStatistics: InternalRow): Unit = {
val columnStatsName = columnStatsClass.getSimpleName
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 061efb37a0ac..6daddfb2c480 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -18,26 +18,26 @@
package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import java.sql.Timestamp
-import com.esotericsoftware.kryo.{Serializer, Kryo}
import com.esotericsoftware.kryo.io.{Input, Output}
-import org.apache.spark.serializer.KryoRegistrator
-import org.scalatest.FunSuite
+import com.esotericsoftware.kryo.{Kryo, Serializer}
-import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.KryoRegistrator
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
-class ColumnTypeSuite extends FunSuite with Logging {
+class ColumnTypeSuite extends SparkFunSuite with Logging {
val DEFAULT_BUFFER_SIZE = 512
test("defaultSize") {
val checks = Map(
INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
- FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 12,
+ FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 8,
BINARY -> 16, GENERIC -> 16)
checks.foreach { case (columnType, expectedSize) =>
@@ -68,9 +68,9 @@ class ColumnTypeSuite extends FunSuite with Logging {
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(BOOLEAN, true, 1)
- checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length)
+ checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length)
checkActualSize(DATE, 0, 4)
- checkActualSize(TIMESTAMP, new Timestamp(0L), 12)
+ checkActualSize(TIMESTAMP, 0L, 8)
val binary = Array.fill[Byte](4)(0: Byte)
checkActualSize(BINARY, binary, 4 + 4)
@@ -120,7 +120,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes)
- UTF8String(bytes)
+ UTF8String.fromBytes(bytes)
})
testColumnType[BinaryType.type, Array[Byte]](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index 75d993e563e0..7c86eae3f77f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -17,14 +17,12 @@
package org.apache.spark.sql.columnar
-import java.sql.Timestamp
-
import scala.collection.immutable.HashSet
import scala.util.Random
-
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, AtomicType}
+import org.apache.spark.sql.types.{DataType, Decimal, AtomicType}
+import org.apache.spark.unsafe.types.UTF8String
object ColumnarTestUtils {
def makeNullRow(length: Int): GenericMutableRow = {
@@ -48,14 +46,11 @@ object ColumnarTestUtils {
case FLOAT => Random.nextFloat()
case DOUBLE => Random.nextDouble()
case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
- case STRING => UTF8String(Random.nextString(Random.nextInt(32)))
+ case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
case BOOLEAN => Random.nextBoolean()
case BINARY => randomBytes(Random.nextInt(32))
case DATE => Random.nextInt()
- case TIMESTAMP =>
- val timestamp = new Timestamp(Random.nextLong())
- timestamp.setNanos(Random.nextInt(999999999))
- timestamp
+ case TIMESTAMP => Random.nextLong()
case _ =>
// Using a random one-element map instead of an arbitrary object
Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
@@ -81,9 +76,9 @@ object ColumnarTestUtils {
def makeRandomRow(
head: ColumnType[_ <: DataType, _],
- tail: ColumnType[_ <: DataType, _]*): Row = makeRandomRow(Seq(head) ++ tail)
+ tail: ColumnType[_ <: DataType, _]*): InternalRow = makeRandomRow(Seq(head) ++ tail)
- def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Row = {
+ def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): InternalRow = {
val row = new GenericMutableRow(columnTypes.length)
makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
row(index) = value
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 56591d9dba29..12f95eb557c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -20,19 +20,20 @@ package org.apache.spark.sql.columnar
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{QueryTest, TestData}
+import org.apache.spark.sql.{QueryTest, Row, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
class InMemoryColumnarQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+ import ctx.{logicalPlanToSparkQuery, sql}
+
test("simple columnar query") {
- val plan = executePlan(testData.logicalPlan).executedPlan
+ val plan = ctx.executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
@@ -40,16 +41,16 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("default size avoids broadcast") {
// TODO: Improve this test when we have better statistics
- sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
+ ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
.toDF().registerTempTable("sizeTst")
- cacheTable("sizeTst")
+ ctx.cacheTable("sizeTst")
assert(
- table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
- conf.autoBroadcastJoinThreshold)
+ ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
+ ctx.conf.autoBroadcastJoinThreshold)
}
test("projection") {
- val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan
+ val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().map {
@@ -58,7 +59,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
}
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
- val plan = executePlan(testData.logicalPlan).executedPlan
+ val plan = ctx.executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
@@ -70,7 +71,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM repeatedData"),
repeatedData.collect().toSeq.map(Row.fromTuple))
- cacheTable("repeatedData")
+ ctx.cacheTable("repeatedData")
checkAnswer(
sql("SELECT * FROM repeatedData"),
@@ -82,7 +83,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM nullableRepeatedData"),
nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
- cacheTable("nullableRepeatedData")
+ ctx.cacheTable("nullableRepeatedData")
checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
@@ -94,7 +95,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT time FROM timestamps"),
timestamps.collect().toSeq.map(Row.fromTuple))
- cacheTable("timestamps")
+ ctx.cacheTable("timestamps")
checkAnswer(
sql("SELECT time FROM timestamps"),
@@ -106,7 +107,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM withEmptyParts"),
withEmptyParts.collect().toSeq.map(Row.fromTuple))
- cacheTable("withEmptyParts")
+ ctx.cacheTable("withEmptyParts")
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
@@ -155,7 +156,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
// Create a RDD for the schema
val rdd =
- sparkContext.parallelize((1 to 100), 10).map { i =>
+ ctx.sparkContext.parallelize((1 to 100), 10).map { i =>
Row(
s"str${i}: test cache.",
s"binary${i}: test cache.".getBytes("UTF-8"),
@@ -173,20 +174,20 @@ class InMemoryColumnarQuerySuite extends QueryTest {
new Timestamp(i),
(1 to i).toSeq,
(0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
- Row((i - 0.25).toFloat, (1 to i).toSeq))
+ Row((i - 0.25).toFloat, Seq(true, false, null)))
}
- createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
+ ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
// Cache the table.
sql("cache table InMemoryCache_different_data_types")
// Make sure the table is indeed cached.
- val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan
+ val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan
assert(
- isCached("InMemoryCache_different_data_types"),
+ ctx.isCached("InMemoryCache_different_data_types"),
"InMemoryCache_different_data_types should be cached.")
// Issue a query and check the results.
checkAnswer(
sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"),
- table("InMemoryCache_different_data_types").collect())
- dropTempTable("InMemoryCache_different_data_types")
+ ctx.table("InMemoryCache_different_data_types").collect())
+ ctx.dropTempTable("InMemoryCache_different_data_types")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index a0702144f942..2a6e0c376551 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.DataType
@@ -39,7 +38,7 @@ object TestNullableColumnAccessor {
}
}
-class NullableColumnAccessorSuite extends FunSuite {
+class NullableColumnAccessorSuite extends SparkFunSuite {
import ColumnarTestUtils._
Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index 3a5605d2335d..cb4e9f1eb7f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.columnar
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types._
@@ -35,7 +34,7 @@ object TestNullableColumnBuilder {
}
}
-class NullableColumnBuilderSuite extends FunSuite {
+class NullableColumnBuilderSuite extends SparkFunSuite {
import ColumnarTestUtils._
Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 2a0b701cad7f..6545c6b314a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -17,43 +17,46 @@
package org.apache.spark.sql.columnar
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
- val originalColumnBatchSize = conf.columnBatchSize
- val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning
+class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
+
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
+ private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
+ private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning
override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
- setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
+ ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
- val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
+ val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key =>
val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
TestData(key, string)
}, 5).toDF()
pruningData.registerTempTable("pruningData")
// Enable in-memory partition pruning
- setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+ ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
// Enable in-memory table scan accumulators
- setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
+ ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
}
override protected def afterAll(): Unit = {
- setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
- setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
+ ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
+ ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
}
before {
- cacheTable("pruningData")
+ ctx.cacheTable("pruningData")
}
after {
- uncacheTable("pruningData")
+ ctx.uncacheTable("pruningData")
}
// Comparisons
@@ -107,7 +110,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
expectedQueryResult: => Seq[Int]): Unit = {
test(query) {
- val df = sql(query)
+ val df = ctx.sql(query)
val queryExecution = df.queryExecution
assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index 8b518f094174..f606e2133bed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.columnar.compression
-import org.scalatest.FunSuite
-
-import org.apache.spark.sql.Row
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN}
import org.apache.spark.sql.columnar.ColumnarTestUtils._
+import org.apache.spark.sql.columnar.{BOOLEAN, NoopColumnStats}
-class BooleanBitSetSuite extends FunSuite {
+class BooleanBitSetSuite extends SparkFunSuite {
import BooleanBitSet._
def skeleton(count: Int) {
@@ -33,7 +32,7 @@ class BooleanBitSetSuite extends FunSuite {
// -------------
val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet)
- val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN))
+ val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN))
val values = rows.map(_(0))
rows.foreach(builder.appendFrom(_, 0))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
index cef60ec204fa..acfab6586c0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql.columnar.compression
import java.nio.ByteBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types.AtomicType
-class DictionaryEncodingSuite extends FunSuite {
+class DictionaryEncodingSuite extends SparkFunSuite {
testDictionaryEncoding(new IntColumnStats, INT)
testDictionaryEncoding(new LongColumnStats, LONG)
testDictionaryEncoding(new StringColumnStats, STRING)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
index 5514590541dd..2111e9fbe62c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.columnar.compression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types.IntegralType
-class IntegralDeltaSuite extends FunSuite {
+class IntegralDeltaSuite extends SparkFunSuite {
testIntegralDelta(new IntColumnStats, INT, IntDelta)
testIntegralDelta(new LongColumnStats, LONG, LongDelta)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
index 6ee48f629191..67ec08f594a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.columnar.compression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types.AtomicType
-class RunLengthEncodingSuite extends FunSuite {
+class RunLengthEncodingSuite extends SparkFunSuite {
testRunLengthEncoding(new NoopColumnStats, BOOLEAN)
testRunLengthEncoding(new ByteColumnStats, BYTE)
testRunLengthEncoding(new ShortColumnStats, SHORT)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 523be56df65b..3e27f58a92d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -17,21 +17,19 @@
package org.apache.spark.sql.execution
-import org.scalatest.FunSuite
-
-import org.apache.spark.sql.{SQLConf, execution}
-import org.apache.spark.sql.functions._
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Row, SQLConf, execution}
-class PlannerSuite extends FunSuite {
+class PlannerSuite extends SparkFunSuite {
test("unions are collapsed") {
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
val planned = BasicOperators(query).head
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 15337c404543..8631e247c6c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -19,17 +19,17 @@ package org.apache.spark.sql.execution
import java.sql.{Timestamp, Date}
-import org.scalatest.{FunSuite, BeforeAndAfterAll}
+import org.apache.spark.sql.test.TestSQLContext
+import org.scalatest.BeforeAndAfterAll
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.Serializer
-import org.apache.spark.ShuffleDependency
+import org.apache.spark.{ShuffleDependency, SparkFunSuite}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
-import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
-class SparkSqlSerializer2DataTypeSuite extends FunSuite {
+class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
// Make sure that we will not use serializer2 for unsupported data types.
def checkSupported(dataType: DataType, isSupported: Boolean): Unit = {
val testName =
@@ -74,11 +74,13 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
var numShufflePartitions: Int = _
var useSerializer2: Boolean = _
+ protected lazy val ctx = TestSQLContext
+
override def beforeAll(): Unit = {
- numShufflePartitions = conf.numShufflePartitions
- useSerializer2 = conf.useSqlSerializer2
+ numShufflePartitions = ctx.conf.numShufflePartitions
+ useSerializer2 = ctx.conf.useSqlSerializer2
- sql("set spark.sql.useSerializer2=true")
+ ctx.sql("set spark.sql.useSerializer2=true")
val supportedTypes =
Seq(StringType, BinaryType, NullType, BooleanType,
@@ -94,7 +96,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
// Create a RDD with all data types supported by SparkSqlSerializer2.
val rdd =
- sparkContext.parallelize((1 to 1000), 10).map { i =>
+ ctx.sparkContext.parallelize((1 to 1000), 10).map { i =>
Row(
s"str${i}: test serializer2.",
s"binary${i}: test serializer2.".getBytes("UTF-8"),
@@ -112,15 +114,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
new Timestamp(i))
}
- createDataFrame(rdd, schema).registerTempTable("shuffle")
+ ctx.createDataFrame(rdd, schema).registerTempTable("shuffle")
super.beforeAll()
}
override def afterAll(): Unit = {
- dropTempTable("shuffle")
- sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
- sql(s"set spark.sql.useSerializer2=$useSerializer2")
+ ctx.dropTempTable("shuffle")
+ ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
+ ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2")
super.afterAll()
}
@@ -141,16 +143,16 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
}
test("key schema and value schema are not nulls") {
- val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
+ val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
- table("shuffle").collect())
+ ctx.table("shuffle").collect())
}
test("key schema is null") {
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
- val df = sql(s"SELECT $aggregations FROM shuffle")
+ val df = ctx.sql(s"SELECT $aggregations FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
@@ -158,15 +160,14 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
}
test("value schema is null") {
- val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
+ val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
- assert(
- df.map(r => r.getString(0)).collect().toSeq ===
- table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+ assert(df.map(r => r.getString(0)).collect().toSeq ===
+ ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
}
test("no map output field") {
- val df = sql(s"SELECT 1 + 1 FROM shuffle")
+ val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
}
}
@@ -177,8 +178,8 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
super.beforeAll()
// Sort merge will not be triggered.
val bypassMergeThreshold =
- sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
- sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
+ ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
}
}
@@ -189,7 +190,7 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite
super.beforeAll()
// To trigger the sort merge.
val bypassMergeThreshold =
- sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
- sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
+ ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 358d8cf06e46..8ec3985e0036 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql.execution.debug
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext._
-class DebuggingSuite extends FunSuite {
+class DebuggingSuite extends SparkFunSuite {
test("DataFrame.debug()") {
testData.debug()
}
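
Note: the change above swaps ScalaTest's `FunSuite` for Spark's own `SparkFunSuite` base class so that common setup can live in one place. A minimal sketch of a migrated suite, with a hypothetical class name:

    import org.apache.spark.SparkFunSuite

    class ExampleDebuggingSuite extends SparkFunSuite {
      test("simple sanity check") {
        assert(1 + 1 === 2)
      }
    }
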
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 2aad01ded1ac..71db6a215985 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -17,47 +17,46 @@
package org.apache.spark.sql.execution.joins
-import org.scalatest.FunSuite
-
-import org.apache.spark.sql.catalyst.expressions.{Projection, Row}
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.{Projection, InternalRow}
import org.apache.spark.util.collection.CompactBuffer
-class HashedRelationSuite extends FunSuite {
+class HashedRelationSuite extends SparkFunSuite {
// Key is simply the record itself
private val keyProjection = new Projection {
- override def apply(row: Row): Row = row
+ override def apply(row: InternalRow): InternalRow = row
}
test("GeneralHashedRelation") {
- val data = Array(Row(0), Row(1), Row(2), Row(2))
+ val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[GeneralHashedRelation])
- assert(hashed.get(data(0)) == CompactBuffer[Row](data(0)))
- assert(hashed.get(data(1)) == CompactBuffer[Row](data(1)))
- assert(hashed.get(Row(10)) === null)
+ assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
+ assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
+ assert(hashed.get(InternalRow(10)) === null)
- val data2 = CompactBuffer[Row](data(2))
+ val data2 = CompactBuffer[InternalRow](data(2))
data2 += data(2)
assert(hashed.get(data(2)) == data2)
}
test("UniqueKeyHashedRelation") {
- val data = Array(Row(0), Row(1), Row(2))
+ val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
- assert(hashed.get(data(0)) == CompactBuffer[Row](data(0)))
- assert(hashed.get(data(1)) == CompactBuffer[Row](data(1)))
- assert(hashed.get(data(2)) == CompactBuffer[Row](data(2)))
- assert(hashed.get(Row(10)) === null)
+ assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
+ assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
+ assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2)))
+ assert(hashed.get(InternalRow(10)) === null)
val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
assert(uniqHashed.getValue(data(0)) == data(0))
assert(uniqHashed.getValue(data(1)) == data(1))
assert(uniqHashed.getValue(data(2)) == data(2))
- assert(uniqHashed.getValue(Row(10)) == null)
+ assert(uniqHashed.getValue(InternalRow(10)) == null)
}
}
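
Note: the hunks above replace the external `Row` with the catalyst `InternalRow` when exercising hashed relations. A small sketch of that shape, mirroring the identity key projection used in the suite; the object name is hypothetical:

    import org.apache.spark.sql.catalyst.expressions.{InternalRow, Projection}
    import org.apache.spark.util.collection.CompactBuffer

    object InternalRowExample {
      // Identity projection over the internal row representation.
      val identityKey = new Projection {
        override def apply(row: InternalRow): InternalRow = row
      }
      val rows = Array(InternalRow(0), InternalRow(1))
      val grouped = CompactBuffer[InternalRow](rows(0))
    }
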
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 30279f528944..69ab1c292d22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -21,14 +21,13 @@ import java.math.BigDecimal
import java.sql.DriverManager
import java.util.{Calendar, GregorianCalendar, Properties}
-import org.apache.spark.sql.test._
-import org.apache.spark.sql.types._
import org.h2.jdbc.JdbcSQLException
-import org.scalatest.{FunSuite, BeforeAndAfter}
-import TestSQLContext._
-import TestSQLContext.implicits._
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
-class JDBCSuite extends FunSuite with BeforeAndAfter {
+class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb0"
val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
var conn: java.sql.Connection = null
@@ -36,12 +35,16 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
val testH2Dialect = new JdbcDialect {
- def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
+ override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
Some(StringType)
}
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+ import ctx.sql
+
before {
Class.forName("org.h2.Driver")
// Extra properties that will be specified for our database. We need these to test
@@ -67,7 +70,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
-
+
sql(
s"""
|CREATE TEMPORARY TABLE fetchtwo
@@ -75,7 +78,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
|OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass',
| fetchSize '2')
""".stripMargin.replaceAll("\n", " "))
-
+
sql(
s"""
|CREATE TEMPORARY TABLE parts
@@ -208,7 +211,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
assert(ids(1) === 2)
assert(ids(2) === 3)
}
-
+
test("SELECT second field when fetchSize is two") {
val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _)
assert(ids.size === 3)
@@ -252,26 +255,26 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
}
test("Basic API") {
- assert(TestSQLContext.read.jdbc(
+ assert(ctx.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
}
test("Basic API with FetchSize") {
val properties = new Properties
properties.setProperty("fetchSize", "2")
- assert(TestSQLContext.read.jdbc(
+ assert(ctx.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3)
}
test("Partitioning via JDBCPartitioningInfo API") {
assert(
- TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
+ ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
.collect().length === 3)
}
test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
- assert(TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
+ assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
.collect().length === 3)
}
@@ -323,13 +326,13 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
assert(cal.get(Calendar.HOUR) === 11)
assert(cal.get(Calendar.MINUTE) === 22)
assert(cal.get(Calendar.SECOND) === 33)
- assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543)
+ assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543500)
}
test("test DATE types") {
- val rows = TestSQLContext.read.jdbc(
+ val rows = ctx.read.jdbc(
urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
- val cachedRows = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
assert(rows(1).getAs[java.sql.Date](1) === null)
@@ -337,9 +340,8 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
}
test("test DATE types in cache") {
- val rows =
- TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
- TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
+ ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().registerTempTable("mycached_date")
val cachedRows = sql("select * from mycached_date").collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
@@ -347,7 +349,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
}
test("test types for null value") {
- val rows = TestSQLContext.read.jdbc(
+ val rows = ctx.read.jdbc(
urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect()
assert((0 to 14).forall(i => rows(0).isNullAt(i)))
}
@@ -394,10 +396,8 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
test("Remap types via JdbcDialects") {
JdbcDialects.registerDialect(testH2Dialect)
- val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
- assert(df.schema.filter(
- _.dataType != org.apache.spark.sql.types.StringType
- ).isEmpty)
+ val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
+ assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty)
val rows = df.collect()
assert(rows(0).get(0).isInstanceOf[String])
assert(rows(0).get(1).isInstanceOf[String])
@@ -410,6 +410,17 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
assert(JdbcDialects.get("test.invalid") == NoopDialect)
}
+ test("quote column names by jdbc dialect") {
+ val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
+ val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
+
+ val columns = Seq("abc", "key")
+ val MySQLColumns = columns.map(MySQL.quoteIdentifier(_))
+ val PostgresColumns = columns.map(Postgres.quoteIdentifier(_))
+ assert(MySQLColumns === Seq("`abc`", "`key`"))
+ assert(PostgresColumns === Seq(""""abc"""", """"key""""))
+ }
+
test("Dialect unregister") {
JdbcDialects.registerDialect(testH2Dialect)
JdbcDialects.unregisterDialect(testH2Dialect)
@@ -418,7 +429,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
test("Aggregated dialects") {
val agg = new AggregatedDialect(List(new JdbcDialect {
- def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
+ override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
if (sqlType % 2 == 0) {
@@ -429,8 +440,8 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
}, testH2Dialect))
assert(agg.canHandle("jdbc:h2:xxx"))
assert(!agg.canHandle("jdbc:h2"))
- assert(agg.getCatalystType(0, "", 1, null) == Some(LongType))
- assert(agg.getCatalystType(1, "", 1, null) == Some(StringType))
+ assert(agg.getCatalystType(0, "", 1, null) === Some(LongType))
+ assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
}
}
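
Note: besides routing reads through `ctx.read.jdbc`, the hunks above add explicit `override` modifiers on custom dialects and a test for dialect-specific identifier quoting. A hedged sketch of a dialect in the same style as `testH2Dialect`; the object name is hypothetical:

    import org.apache.spark.sql.jdbc.JdbcDialect
    import org.apache.spark.sql.types.{DataType, MetadataBuilder, StringType}

    // Both members carry explicit override modifiers, matching the style above.
    object ExampleH2Dialect extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2")
      override def getCatalystType(
          sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
        Some(StringType)
    }
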
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 2e4c12f9da80..d949ef42267e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.jdbc
import java.sql.DriverManager
import java.util.Properties
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{SaveMode, Row}
-import org.apache.spark.sql.test._
import org.apache.spark.sql.types._
-class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
+class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb2"
var conn: java.sql.Connection = null
val url1 = "jdbc:h2:mem:testdb3"
@@ -35,12 +35,16 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
properties.setProperty("user", "testUser")
properties.setProperty("password", "testPass")
properties.setProperty("rowId", "false")
-
+
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+ import ctx.sql
+
before {
Class.forName("org.h2.Driver")
conn = DriverManager.getConnection(url)
conn.prepareStatement("create schema test").executeUpdate()
-
+
conn1 = DriverManager.getConnection(url1, properties)
conn1.prepareStatement("create schema test").executeUpdate()
conn1.prepareStatement("drop table if exists test.people").executeUpdate()
@@ -52,20 +56,20 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
conn1.prepareStatement(
"create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
conn1.commit()
-
- TestSQLContext.sql(
+
+ ctx.sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
-
- TestSQLContext.sql(
+
+ ctx.sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE1
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass')
- """.stripMargin.replaceAll("\n", " "))
+ """.stripMargin.replaceAll("\n", " "))
}
after {
@@ -73,66 +77,64 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
conn1.close()
}
- val sc = TestSQLContext.sparkContext
+ private lazy val sc = ctx.sparkContext
- val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222))
- val arr1x2 = Array[Row](Row.apply("fred", 3))
- val schema2 = StructType(
+ private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222))
+ private lazy val arr1x2 = Array[Row](Row.apply("fred", 3))
+ private lazy val schema2 = StructType(
StructField("name", StringType) ::
StructField("id", IntegerType) :: Nil)
- val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2))
- val schema3 = StructType(
+ private lazy val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2))
+ private lazy val schema3 = StructType(
StructField("name", StringType) ::
StructField("id", IntegerType) ::
StructField("seq", IntegerType) :: Nil)
test("Basic CREATE") {
- val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties)
- assert(2 == TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
- assert(2 ==
- TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
+ assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
+ assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
}
test("CREATE with overwrite") {
- val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
- val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
+ val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3)
+ val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.DROPTEST", properties)
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count)
- assert(3 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
+ assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties)
- assert(1 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count)
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
+ assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
}
test("CREATE then INSERT to append") {
- val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
- val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
+ val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
df.write.jdbc(url, "TEST.APPENDTEST", new Properties)
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties)
- assert(3 == TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
- assert(2 ==
- TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
+ assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
+ assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
}
test("CREATE then INSERT to truncate") {
- val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
- val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
+ val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.TRUNCATETEST", properties)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties)
- assert(1 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
+ assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
+ assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
}
test("Incompatible INSERT to append") {
- val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
- val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
+ val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3)
df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties)
intercept[org.apache.spark.SparkException] {
@@ -141,15 +143,15 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("INSERT to JDBC Datasource") {
- TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+ ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
test("INSERT to JDBC Datasource with overwrite") {
- TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
- TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
- }
+ ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+ }
}
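
Note: the write-suite hunks above follow the same pattern (shared `ctx`, lazy test fixtures, `===` in asserts). A rough round-trip sketch under the same assumptions as the suite: H2 is on the classpath and a connection is held open so the in-memory database survives between write and read. The URL and table name are placeholders:

    import java.util.Properties
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    object JdbcWriteExample {
      def main(args: Array[String]): Unit = {
        Class.forName("org.h2.Driver")
        // Keep one connection open so the mem database is not dropped early.
        val keepAlive = java.sql.DriverManager.getConnection("jdbc:h2:mem:exampledb")
        val ctx = org.apache.spark.sql.test.TestSQLContext
        val schema = StructType(
          StructField("name", StringType) :: StructField("id", IntegerType) :: Nil)
        val df = ctx.createDataFrame(
          ctx.sparkContext.parallelize(Seq(Row("dave", 42))), schema)
        df.write.jdbc("jdbc:h2:mem:exampledb", "TEST.EXAMPLE", new Properties)
        assert(ctx.read.jdbc("jdbc:h2:mem:exampledb", "TEST.EXAMPLE", new Properties).count() == 1)
        keepAlive.close()
      }
    }
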
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index f8d62f9e7e02..fca24364fe6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -23,21 +23,19 @@ import java.sql.{Date, Timestamp}
import com.fasterxml.jackson.core.JsonFactory
import org.scalactic.Tolerance._
+import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.json.InferSchema.compatibleType
import org.apache.spark.sql.sources.LogicalRelation
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.util.Utils
-class JsonSuite extends QueryTest {
- import org.apache.spark.sql.json.TestJsonData._
+class JsonSuite extends QueryTest with TestJsonData {
- TestJsonData
+ protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.sql
+ import ctx.implicits._
test("Type promotion") {
def checkTypePromotion(expected: Any, actual: Any) {
@@ -78,21 +76,25 @@ class JsonSuite extends QueryTest {
checkTypePromotion(
Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited))
- checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType))
- checkTypePromotion(new Timestamp(intNumber.toLong),
+ checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(intNumber)),
+ enforceCorrectType(intNumber, TimestampType))
+ checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)),
enforceCorrectType(intNumber.toLong, TimestampType))
val strTime = "2014-09-30 12:34:56"
- checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType))
+ checkTypePromotion(DateUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)),
+ enforceCorrectType(strTime, TimestampType))
val strDate = "2014-10-15"
checkTypePromotion(
DateUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType))
val ISO8601Time1 = "1970-01-01T01:00:01.0Z"
- checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType))
+ checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(3601000)),
+ enforceCorrectType(ISO8601Time1, TimestampType))
checkTypePromotion(DateUtils.millisToDays(3601000), enforceCorrectType(ISO8601Time1, DateType))
val ISO8601Time2 = "1970-01-01T02:00:01-01:00"
- checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType))
+ checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(10801000)),
+ enforceCorrectType(ISO8601Time2, TimestampType))
checkTypePromotion(DateUtils.millisToDays(10801000), enforceCorrectType(ISO8601Time2, DateType))
}
@@ -214,7 +216,7 @@ class JsonSuite extends QueryTest {
}
test("Complex field and type inferring with null in sampling") {
- val jsonDF = read.json(jsonNullStruct)
+ val jsonDF = ctx.read.json(jsonNullStruct)
val expectedSchema = StructType(
StructField("headers", StructType(
StructField("Charset", StringType, true) ::
@@ -233,7 +235,7 @@ class JsonSuite extends QueryTest {
}
test("Primitive field and type inferring") {
- val jsonDF = read.json(primitiveFieldAndType)
+ val jsonDF = ctx.read.json(primitiveFieldAndType)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType.Unlimited, true) ::
@@ -261,7 +263,7 @@ class JsonSuite extends QueryTest {
}
test("Complex field and type inferring") {
- val jsonDF = read.json(complexFieldAndType1)
+ val jsonDF = ctx.read.json(complexFieldAndType1)
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
@@ -360,7 +362,7 @@ class JsonSuite extends QueryTest {
}
test("GetField operation on complex data type") {
- val jsonDF = read.json(complexFieldAndType1)
+ val jsonDF = ctx.read.json(complexFieldAndType1)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -376,7 +378,7 @@ class JsonSuite extends QueryTest {
}
test("Type conflict in primitive field values") {
- val jsonDF = read.json(primitiveFieldValueTypeConflict)
+ val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict)
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
@@ -450,7 +452,7 @@ class JsonSuite extends QueryTest {
}
ignore("Type conflict in primitive field values (Ignored)") {
- val jsonDF = read.json(primitiveFieldValueTypeConflict)
+ val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict)
jsonDF.registerTempTable("jsonTable")
// Right now, the analyzer does not promote strings in a boolean expression.
@@ -503,7 +505,7 @@ class JsonSuite extends QueryTest {
}
test("Type conflict in complex field values") {
- val jsonDF = read.json(complexFieldValueTypeConflict)
+ val jsonDF = ctx.read.json(complexFieldValueTypeConflict)
val expectedSchema = StructType(
StructField("array", ArrayType(LongType, true), true) ::
@@ -527,7 +529,7 @@ class JsonSuite extends QueryTest {
}
test("Type conflict in array elements") {
- val jsonDF = read.json(arrayElementTypeConflict)
+ val jsonDF = ctx.read.json(arrayElementTypeConflict)
val expectedSchema = StructType(
StructField("array1", ArrayType(StringType, true), true) ::
@@ -555,7 +557,7 @@ class JsonSuite extends QueryTest {
}
test("Handling missing fields") {
- val jsonDF = read.json(missingFields)
+ val jsonDF = ctx.read.json(missingFields)
val expectedSchema = StructType(
StructField("a", BooleanType, true) ::
@@ -574,8 +576,9 @@ class JsonSuite extends QueryTest {
val dir = Utils.createTempDir()
dir.delete()
val path = dir.getCanonicalPath
- sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
- val jsonDF = read.option("samplingRatio", "0.49").json(path)
+ ctx.sparkContext.parallelize(1 to 100)
+ .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
+ val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path)
val analyzed = jsonDF.queryExecution.analyzed
assert(
@@ -590,7 +593,7 @@ class JsonSuite extends QueryTest {
val schema = StructType(StructField("a", LongType, true) :: Nil)
val logicalRelation =
- read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
+ ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
assert(relationWithSchema.path === Some(path))
assert(relationWithSchema.schema === schema)
@@ -602,7 +605,7 @@ class JsonSuite extends QueryTest {
dir.delete()
val path = dir.getCanonicalPath
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
- val jsonDF = read.json(path)
+ val jsonDF = ctx.read.json(path)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType.Unlimited, true) ::
@@ -671,7 +674,7 @@ class JsonSuite extends QueryTest {
StructField("null", StringType, true) ::
StructField("string", StringType, true) :: Nil)
- val jsonDF1 = read.schema(schema).json(path)
+ val jsonDF1 = ctx.read.schema(schema).json(path)
assert(schema === jsonDF1.schema)
@@ -688,7 +691,7 @@ class JsonSuite extends QueryTest {
"this is a simple string.")
)
- val jsonDF2 = read.schema(schema).json(primitiveFieldAndType)
+ val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType)
assert(schema === jsonDF2.schema)
@@ -709,7 +712,7 @@ class JsonSuite extends QueryTest {
test("Applying schemas with MapType") {
val schemaWithSimpleMap = StructType(
StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
- val jsonWithSimpleMap = read.schema(schemaWithSimpleMap).json(mapType1)
+ val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1)
jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap")
@@ -737,7 +740,7 @@ class JsonSuite extends QueryTest {
val schemaWithComplexMap = StructType(
StructField("map", MapType(StringType, innerStruct, true), false) :: Nil)
- val jsonWithComplexMap = read.schema(schemaWithComplexMap).json(mapType2)
+ val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2)
jsonWithComplexMap.registerTempTable("jsonWithComplexMap")
@@ -763,7 +766,7 @@ class JsonSuite extends QueryTest {
}
test("SPARK-2096 Correctly parse dot notations") {
- val jsonDF = read.json(complexFieldAndType2)
+ val jsonDF = ctx.read.json(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -781,7 +784,7 @@ class JsonSuite extends QueryTest {
}
test("SPARK-3390 Complex arrays") {
- val jsonDF = read.json(complexFieldAndType2)
+ val jsonDF = ctx.read.json(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -804,7 +807,7 @@ class JsonSuite extends QueryTest {
}
test("SPARK-3308 Read top level JSON arrays") {
- val jsonDF = read.json(jsonArray)
+ val jsonDF = ctx.read.json(jsonArray)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -822,10 +825,10 @@ class JsonSuite extends QueryTest {
test("Corrupt records") {
// Test if we can query corrupt records.
- val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord
- TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
+ val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord
+ ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
- val jsonDF = read.json(corruptRecords)
+ val jsonDF = ctx.read.json(corruptRecords)
jsonDF.registerTempTable("jsonTable")
val schema = StructType(
@@ -875,11 +878,11 @@ class JsonSuite extends QueryTest {
Row("]") :: Nil
)
- TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
+ ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
}
test("SPARK-4068: nulls in arrays") {
- val jsonDF = read.json(nullsInArrays)
+ val jsonDF = ctx.read.json(nullsInArrays)
jsonDF.registerTempTable("jsonTable")
val schema = StructType(
@@ -925,7 +928,7 @@ class JsonSuite extends QueryTest {
Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
}
- val df1 = createDataFrame(rowRDD1, schema1)
+ val df1 = ctx.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
val df2 = df1.toDF
val result = df2.toJSON.collect()
@@ -948,7 +951,7 @@ class JsonSuite extends QueryTest {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df3 = createDataFrame(rowRDD2, schema2)
+ val df3 = ctx.createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
val df4 = df3.toDF
val result2 = df4.toJSON.collect()
@@ -956,8 +959,8 @@ class JsonSuite extends QueryTest {
assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
- val jsonDF = read.json(primitiveFieldAndType)
- val primTable = read.json(jsonDF.toJSON)
+ val jsonDF = ctx.read.json(primitiveFieldAndType)
+ val primTable = ctx.read.json(jsonDF.toJSON)
primTable.registerTempTable("primativeTable")
checkAnswer(
sql("select * from primativeTable"),
@@ -969,8 +972,8 @@ class JsonSuite extends QueryTest {
"this is a simple string.")
)
- val complexJsonDF = read.json(complexFieldAndType1)
- val compTable = read.json(complexJsonDF.toJSON)
+ val complexJsonDF = ctx.read.json(complexFieldAndType1)
+ val compTable = ctx.read.json(complexJsonDF.toJSON)
compTable.registerTempTable("complexTable")
// Access elements of a primitive array.
checkAnswer(
@@ -1074,29 +1077,29 @@ class JsonSuite extends QueryTest {
}
test("SPARK-7565 MapType in JsonRDD") {
- val useStreaming = getConf(SQLConf.USE_JACKSON_STREAMING_API, "true")
- val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord
- TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
+ val useStreaming = ctx.getConf(SQLConf.USE_JACKSON_STREAMING_API, "true")
+ val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord
+ ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
val schemaWithSimpleMap = StructType(
StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
try{
for (useStreaming <- List("true", "false")) {
- setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
+ ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
val temp = Utils.createTempDir().getPath
- val df = read.schema(schemaWithSimpleMap).json(mapType1)
+ val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1)
df.write.mode("overwrite").parquet(temp)
// order of MapType is not defined
- assert(read.parquet(temp).count() == 5)
+ assert(ctx.read.parquet(temp).count() == 5)
- val df2 = read.json(corruptRecords)
+ val df2 = ctx.read.json(corruptRecords)
df2.write.mode("overwrite").parquet(temp)
- checkAnswer(read.parquet(temp), df2.collect())
+ checkAnswer(ctx.read.parquet(temp), df2.collect())
}
} finally {
- setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
- setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
+ ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
+ ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
}
}
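
Note: in the JSON suite above, reads move to `ctx.read.json`, and expected timestamps are now wrapped in `DateUtils.fromJavaTimestamp` because the comparison happens against the internal timestamp representation. A minimal sketch of the `ctx.read` pattern, assuming the same `TestSQLContext` harness; names are illustrative:

    object JsonReadExample {
      def main(args: Array[String]): Unit = {
        val ctx = org.apache.spark.sql.test.TestSQLContext
        val records = ctx.sparkContext.parallelize(
          """{"a": 1, "b": "str1"}""" :: """{"a": 2, "b": "str2"}""" :: Nil)
        val df = ctx.read.json(records)
        df.registerTempTable("example_json")
        ctx.sql("SELECT b FROM example_json WHERE a = 2").show()
      }
    }
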
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
index 47a97a49daab..b6a6a8dc6a63 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql.json
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
-object TestJsonData {
+trait TestJsonData {
- val primitiveFieldAndType =
- TestSQLContext.sparkContext.parallelize(
+ protected def ctx: SQLContext
+
+ def primitiveFieldAndType: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"string":"this is a simple string.",
"integer":10,
"long":21474836470,
@@ -32,8 +35,8 @@ object TestJsonData {
"null":null
}""" :: Nil)
- val primitiveFieldValueTypeConflict =
- TestSQLContext.sparkContext.parallelize(
+ def primitiveFieldValueTypeConflict: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
"num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
"""{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
@@ -43,15 +46,15 @@ object TestJsonData {
"""{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470,
"num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)
- val jsonNullStruct =
- TestSQLContext.sparkContext.parallelize(
+ def jsonNullStruct: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
- val complexFieldValueTypeConflict =
- TestSQLContext.sparkContext.parallelize(
+ def complexFieldValueTypeConflict: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"num_struct":11, "str_array":[1, 2, 3],
"array":[], "struct_array":[], "struct": {}}""" ::
"""{"num_struct":{"field":false}, "str_array":null,
@@ -61,23 +64,23 @@ object TestJsonData {
"""{"num_struct":{}, "str_array":["str1", "str2", 33],
"array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)
- val arrayElementTypeConflict =
- TestSQLContext.sparkContext.parallelize(
+ def arrayElementTypeConflict: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
"array2": [{"field":214748364700}, {"field":1}]}""" ::
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
"""{"array3": [1, 2, 3]}""" :: Nil)
- val missingFields =
- TestSQLContext.sparkContext.parallelize(
+ def missingFields: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"a":true}""" ::
"""{"b":21474836470}""" ::
"""{"c":[33, 44]}""" ::
"""{"d":{"field":true}}""" ::
"""{"e":"str"}""" :: Nil)
- val complexFieldAndType1 =
- TestSQLContext.sparkContext.parallelize(
+ def complexFieldAndType1: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"struct":{"field1": true, "field2": 92233720368547758070},
"structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
"arrayOfString":["str1", "str2"],
@@ -92,8 +95,8 @@ object TestJsonData {
"arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]]
}""" :: Nil)
- val complexFieldAndType2 =
- TestSQLContext.sparkContext.parallelize(
+ def complexFieldAndType2: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
"complexArrayOfStruct": [
{
@@ -146,16 +149,16 @@ object TestJsonData {
]]
}""" :: Nil)
- val mapType1 =
- TestSQLContext.sparkContext.parallelize(
+ def mapType1: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"map": {"a": 1}}""" ::
"""{"map": {"b": 2}}""" ::
"""{"map": {"c": 3}}""" ::
"""{"map": {"c": 1, "d": 4}}""" ::
"""{"map": {"e": null}}""" :: Nil)
- val mapType2 =
- TestSQLContext.sparkContext.parallelize(
+ def mapType2: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
"""{"map": {"b": {"field2": 2}}}""" ::
"""{"map": {"c": {"field1": [], "field2": 4}}}""" ::
@@ -163,22 +166,22 @@ object TestJsonData {
"""{"map": {"e": null}}""" ::
"""{"map": {"f": {"field1": null}}}""" :: Nil)
- val nullsInArrays =
- TestSQLContext.sparkContext.parallelize(
+ def nullsInArrays: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"field1":[[null], [[["Test"]]]]}""" ::
"""{"field2":[null, [{"Test":1}]]}""" ::
"""{"field3":[[null], [{"Test":"2"}]]}""" ::
"""{"field4":[[null, [1,2,3]]]}""" :: Nil)
- val jsonArray =
- TestSQLContext.sparkContext.parallelize(
+ def jsonArray: RDD[String] =
+ ctx.sparkContext.parallelize(
"""[{"a":"str_a_1"}]""" ::
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)
- val corruptRecords =
- TestSQLContext.sparkContext.parallelize(
+ def corruptRecords: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a":1, b:2}""" ::
@@ -186,6 +189,5 @@ object TestJsonData {
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""]""" :: Nil)
- val empty =
- TestSQLContext.sparkContext.parallelize(Seq[String]())
+ def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]())
}
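
Note: the file above turns an object full of eagerly built RDD vals into a trait whose data sets are defs built from whichever SQLContext the mixing suite provides. A sketch of that shape with a hypothetical trait and field:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SQLContext

    // Data sets are produced on demand from the suite's own context.
    trait ExampleTestData {
      protected def ctx: SQLContext

      def simpleRecords: RDD[String] =
        ctx.sparkContext.parallelize("""{"a": 1}""" :: """{"a": 2}""" :: Nil)
    }
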
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index bdc2ebabc5e9..fa5d4eca05d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -18,16 +18,15 @@
package org.apache.spark.sql.parquet
import org.scalatest.BeforeAndAfterAll
-import parquet.filter2.predicate.Operators._
-import parquet.filter2.predicate.{FilterPredicate, Operators}
+import org.apache.parquet.filter2.predicate.Operators._
+import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.sources.LogicalRelation
-import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf}
+import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
/**
* A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -42,7 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf}
* data type is nullable.
*/
class ParquetFilterSuiteBase extends QueryTest with ParquetTest {
- val sqlContext = TestSQLContext
+ lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
private def checkFilterPredicate(
df: DataFrame,
@@ -312,7 +311,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest {
}
class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
@@ -341,7 +340,7 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA
}
class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
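
Note: the Parquet suites above pick up the relocated org.apache.parquet.* package names and switch eager vals to lazy vals so the shared context is not touched while the suite class is constructed. A minimal sketch of the lazy-initialization shape, with a hypothetical class name:

    class ExampleLazySuite extends org.apache.spark.sql.QueryTest {
      // Neither value forces TestSQLContext to start until a test uses it.
      lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
      lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
    }
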
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index dd48bb350f26..fc827bc4ca11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -23,24 +23,22 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
+import org.apache.parquet.example.data.simple.SimpleGroup
+import org.apache.parquet.example.data.{Group, GroupWriter}
+import org.apache.parquet.hadoop.api.WriteSupport
+import org.apache.parquet.hadoop.api.WriteSupport.WriteContext
+import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata}
+import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter}
+import org.apache.parquet.io.api.RecordConsumer
+import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.scalatest.BeforeAndAfterAll
-import parquet.example.data.simple.SimpleGroup
-import parquet.example.data.{Group, GroupWriter}
-import parquet.hadoop.api.WriteSupport
-import parquet.hadoop.api.WriteSupport.WriteContext
-import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName}
-import parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter}
-import parquet.io.api.RecordConsumer
-import parquet.schema.{MessageType, MessageTypeParser}
+import org.apache.spark.SparkException
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode}
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
// with an empty configuration (it is after all not intended to be used in this way?)
@@ -66,9 +64,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
* A test suite that tests basic Parquet I/O.
*/
class ParquetIOSuiteBase extends QueryTest with ParquetTest {
- val sqlContext = TestSQLContext
-
- import sqlContext.implicits.localSeqToDataFrameHolder
+ lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+ import sqlContext.implicits._
/**
* Writes `data` to a Parquet file, reads it back and check file contents.
@@ -104,7 +101,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
test("fixed-length decimals") {
def makeDecimalRDD(decimal: DecimalType): DataFrame =
- sparkContext
+ sqlContext.sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(i / 100.0))
.toDF()
@@ -115,7 +112,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withTempPath { dir =>
val data = makeDecimalRDD(DecimalType(precision, scale))
data.write.parquet(dir.getCanonicalPath)
- checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+ checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
}
}
@@ -123,7 +120,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
intercept[Throwable] {
withTempPath { dir =>
makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath)
- read.parquet(dir.getCanonicalPath).collect()
+ sqlContext.read.parquet(dir.getCanonicalPath).collect()
}
}
@@ -131,14 +128,14 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
intercept[Throwable] {
withTempPath { dir =>
makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath)
- read.parquet(dir.getCanonicalPath).collect()
+ sqlContext.read.parquet(dir.getCanonicalPath).collect()
}
}
}
test("date type") {
def makeDateRDD(): DataFrame =
- sparkContext
+ sqlContext.sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(DateUtils.toJavaDate(i)))
.toDF()
@@ -147,7 +144,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withTempPath { dir =>
val data = makeDateRDD()
data.write.parquet(dir.getCanonicalPath)
- checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+ checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
}
}
@@ -200,7 +197,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetDataFrame(allNulls :: Nil) { df =>
val rows = df.collect()
- assert(rows.size === 1)
+ assert(rows.length === 1)
assert(rows.head === Row(Seq.fill(5)(null): _*))
}
}
@@ -213,7 +210,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetDataFrame(allNones :: Nil) { df =>
val rows = df.collect()
- assert(rows.size === 1)
+ assert(rows.length === 1)
assert(rows.head === Row(Seq.fill(3)(null): _*))
}
}
@@ -236,7 +233,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
def checkCompressionCodec(codec: CompressionCodecName): Unit = {
withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) {
withParquetFile(data) { path =>
- assertResult(conf.parquetCompressionCodec.toUpperCase) {
+ assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) {
compressionCodecFor(path)
}
}
@@ -244,7 +241,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
}
// Checks default compression codec
- checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec))
+ checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec))
checkCompressionCodec(CompressionCodecName.UNCOMPRESSED)
checkCompressionCodec(CompressionCodecName.GZIP)
@@ -283,7 +280,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "part-r-0.parquet")
makeRawParquetFile(path)
- checkAnswer(read.parquet(path.toString), (0 until 10).map { i =>
+ checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i =>
Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
})
}
@@ -312,7 +309,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetFile((1 to 10).map(i => (i, i.toString))) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file)
- checkAnswer(read.parquet(file), newData.map(Row.fromTuple))
+ checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple))
}
}
@@ -321,7 +318,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetFile(data) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file)
- checkAnswer(read.parquet(file), data.map(Row.fromTuple))
+ checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple))
}
}
@@ -341,7 +338,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetFile(data) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file)
- checkAnswer(read.parquet(file), (data ++ newData).map(Row.fromTuple))
+ checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple))
}
}
@@ -369,11 +366,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
val path = new Path(location.getCanonicalPath)
ParquetFileWriter.writeMetadataFile(
- sparkContext.hadoopConfiguration,
+ sqlContext.sparkContext.hadoopConfiguration,
path,
new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil)
- assertResult(read.parquet(path.toString).schema) {
+ assertResult(sqlContext.read.parquet(path.toString).schema) {
StructType(
StructField("a", BooleanType, nullable = false) ::
StructField("b", IntegerType, nullable = false) ::
@@ -383,6 +380,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
}
test("SPARK-6352 DirectParquetOutputCommitter") {
+ val clonedConf = new Configuration(configuration)
+
// Write to a parquet file and let it fail.
// _temporary should be missing if direct output committer works.
try {
@@ -397,16 +396,48 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
val fs = path.getFileSystem(configuration)
assert(!fs.exists(path))
}
+ } finally {
+ // Hadoop 1 doesn't have `Configuration.unset`
+ configuration.clear()
+ clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue))
}
- finally {
- configuration.set("spark.sql.parquet.output.committer.class",
- "parquet.hadoop.ParquetOutputCommitter")
+ }
+
+ test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") {
+ withTempPath { dir =>
+ val clonedConf = new Configuration(configuration)
+
+ configuration.set(
+ SQLConf.OUTPUT_COMMITTER_CLASS, classOf[ParquetOutputCommitter].getCanonicalName)
+
+ configuration.set(
+ "spark.sql.parquet.output.committer.class",
+ classOf[BogusParquetOutputCommitter].getCanonicalName)
+
+ try {
+ val message = intercept[SparkException] {
+ sqlContext.range(0, 1).write.parquet(dir.getCanonicalPath)
+ }.getCause.getMessage
+ assert(message === "Intentional exception for testing purposes")
+ } finally {
+ // Hadoop 1 doesn't have `Configuration.unset`
+ configuration.clear()
+ clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue))
+ }
}
}
}
+class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
+ extends ParquetOutputCommitter(outputPath, context) {
+
+ override def commitJob(jobContext: JobContext): Unit = {
+ sys.error("Intentional exception for testing purposes")
+ }
+}
+
class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
@@ -430,7 +461,7 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA
}
class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
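
Note: the committer tests above clone the Hadoop Configuration before mutating it and copy it back entry by entry, because Hadoop 1 has no `Configuration.unset`. A hedged sketch of that save-and-restore idiom as a reusable helper; the helper name is hypothetical:

    import scala.collection.JavaConverters._
    import org.apache.hadoop.conf.Configuration

    object ConfRestoreExample {
      // Set one entry for the duration of body, then restore the full snapshot.
      def withTemporaryConfEntry(conf: Configuration, key: String, value: String)
          (body: => Unit): Unit = {
        val cloned = new Configuration(conf)
        conf.set(key, value)
        try body finally {
          conf.clear()
          cloned.asScala.foreach(entry => conf.set(entry.getKey, entry.getValue))
        }
      }
    }
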
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
index f231589e9674..01df189d1f3b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
@@ -14,22 +14,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.spark.sql.parquet
import java.io.File
import java.math.BigInteger
-import java.sql.{Timestamp, Date}
+import java.sql.Timestamp
import scala.collection.mutable.ArrayBuffer
+import com.google.common.io.Files
import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.sources.PartitioningUtils._
import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec}
-import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext}
+import org.apache.spark.sql._
+import org.apache.spark.unsafe.types.UTF8String
// The data where the partitioning key exists only in the directory structure.
case class ParquetData(intField: Int, stringField: String)
@@ -38,33 +41,33 @@ case class ParquetData(intField: Int, stringField: String)
case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
- override val sqlContext: SQLContext = TestSQLContext
- import sqlContext._
+ override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
+ import sqlContext.sql
val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"
test("column type inference") {
def check(raw: String, literal: Literal): Unit = {
- assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal)
+ assert(inferPartitionColumnValue(raw, defaultPartitionName, true) === literal)
}
check("10", Literal.create(10, IntegerType))
check("1000000000000000", Literal.create(1000000000000000L, LongType))
- check("1.5", Literal.create(1.5, FloatType))
+ check("1.5", Literal.create(1.5, DoubleType))
check("hello", Literal.create("hello", StringType))
check(defaultPartitionName, Literal.create(null, NullType))
}
test("parse partition") {
def check(path: String, expected: Option[PartitionValues]): Unit = {
- assert(expected === parsePartition(new Path(path), defaultPartitionName))
+ assert(expected === parsePartition(new Path(path), defaultPartitionName, true))
}
def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = {
val message = intercept[T] {
- parsePartition(new Path(path), defaultPartitionName).get
+ parsePartition(new Path(path), defaultPartitionName, true).get
}.getMessage
assert(message.contains(expected))
@@ -82,13 +85,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
ArrayBuffer(
Literal.create(10, IntegerType),
Literal.create("hello", StringType),
- Literal.create(1.5, FloatType)))
+ Literal.create(1.5, DoubleType)))
})
check("file://path/a=10/b_hello/c=1.5", Some {
PartitionValues(
ArrayBuffer("c"),
- ArrayBuffer(Literal.create(1.5, FloatType)))
+ ArrayBuffer(Literal.create(1.5, DoubleType)))
})
check("file:///", None)
@@ -104,7 +107,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
test("parse partitions") {
def check(paths: Seq[String], spec: PartitionSpec): Unit = {
- assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName) === spec)
+ assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) === spec)
}
check(Seq(
@@ -113,18 +116,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
StructType(Seq(
StructField("a", IntegerType),
StructField("b", StringType))),
- Seq(Partition(Row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello"))))
+ Seq(Partition(InternalRow(10, UTF8String.fromString("hello")),
+ "hdfs://host:9000/path/a=10/b=hello"))))
check(Seq(
"hdfs://host:9000/path/a=10/b=20",
"hdfs://host:9000/path/a=10.5/b=hello"),
PartitionSpec(
StructType(Seq(
- StructField("a", FloatType),
+ StructField("a", DoubleType),
StructField("b", StringType))),
Seq(
- Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"),
- Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello"))))
+ Partition(InternalRow(10, UTF8String.fromString("20")),
+ "hdfs://host:9000/path/a=10/b=20"),
+ Partition(InternalRow(10.5, UTF8String.fromString("hello")),
+ "hdfs://host:9000/path/a=10.5/b=hello"))))
check(Seq(
"hdfs://host:9000/path/_temporary",
@@ -139,11 +145,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
"hdfs://host:9000/path/a=10.5/b=world/_temporary/path"),
PartitionSpec(
StructType(Seq(
- StructField("a", FloatType),
+ StructField("a", DoubleType),
StructField("b", StringType))),
Seq(
- Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"),
- Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello"))))
+ Partition(InternalRow(10, UTF8String.fromString("20")),
+ "hdfs://host:9000/path/a=10/b=20"),
+ Partition(InternalRow(10.5, UTF8String.fromString("hello")),
+ "hdfs://host:9000/path/a=10.5/b=hello"))))
check(Seq(
s"hdfs://host:9000/path/a=10/b=20",
@@ -153,19 +161,102 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
StructField("a", IntegerType),
StructField("b", StringType))),
Seq(
- Partition(Row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"),
- Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"))))
+ Partition(InternalRow(10, UTF8String.fromString("20")),
+ s"hdfs://host:9000/path/a=10/b=20"),
+ Partition(InternalRow(null, UTF8String.fromString("hello")),
+ s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"))))
check(Seq(
s"hdfs://host:9000/path/a=10/b=$defaultPartitionName",
s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"),
PartitionSpec(
StructType(Seq(
- StructField("a", FloatType),
+ StructField("a", DoubleType),
StructField("b", StringType))),
Seq(
- Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"),
- Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"))))
+ Partition(InternalRow(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"),
+ Partition(InternalRow(10.5, null),
+ s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"))))
+
+ check(Seq(
+ s"hdfs://host:9000/path1",
+ s"hdfs://host:9000/path2"),
+ PartitionSpec.emptySpec)
+ }
+
+ test("parse partitions with type inference disabled") {
+ def check(paths: Seq[String], spec: PartitionSpec): Unit = {
+ assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, false) === spec)
+ }
+
+ check(Seq(
+ "hdfs://host:9000/path/a=10/b=hello"),
+ PartitionSpec(
+ StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", StringType))),
+ Seq(Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("hello")),
+ "hdfs://host:9000/path/a=10/b=hello"))))
+
+ check(Seq(
+ "hdfs://host:9000/path/a=10/b=20",
+ "hdfs://host:9000/path/a=10.5/b=hello"),
+ PartitionSpec(
+ StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", StringType))),
+ Seq(
+ Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")),
+ "hdfs://host:9000/path/a=10/b=20"),
+ Partition(InternalRow(UTF8String.fromString("10.5"), UTF8String.fromString("hello")),
+ "hdfs://host:9000/path/a=10.5/b=hello"))))
+
+ check(Seq(
+ "hdfs://host:9000/path/_temporary",
+ "hdfs://host:9000/path/a=10/b=20",
+ "hdfs://host:9000/path/a=10.5/b=hello",
+ "hdfs://host:9000/path/a=10.5/_temporary",
+ "hdfs://host:9000/path/a=10.5/_TeMpOrArY",
+ "hdfs://host:9000/path/a=10.5/b=hello/_temporary",
+ "hdfs://host:9000/path/a=10.5/b=hello/_TEMPORARY",
+ "hdfs://host:9000/path/_temporary/path",
+ "hdfs://host:9000/path/a=11/_temporary/path",
+ "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"),
+ PartitionSpec(
+ StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", StringType))),
+ Seq(
+ Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")),
+ "hdfs://host:9000/path/a=10/b=20"),
+ Partition(InternalRow(UTF8String.fromString("10.5"), UTF8String.fromString("hello")),
+ "hdfs://host:9000/path/a=10.5/b=hello"))))
+
+ check(Seq(
+ s"hdfs://host:9000/path/a=10/b=20",
+ s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"),
+ PartitionSpec(
+ StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", StringType))),
+ Seq(
+ Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")),
+ s"hdfs://host:9000/path/a=10/b=20"),
+ Partition(InternalRow(null, UTF8String.fromString("hello")),
+ s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"))))
+
+ check(Seq(
+ s"hdfs://host:9000/path/a=10/b=$defaultPartitionName",
+ s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"),
+ PartitionSpec(
+ StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", StringType))),
+ Seq(
+ Partition(InternalRow(UTF8String.fromString("10"), null),
+ s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"),
+ Partition(InternalRow(UTF8String.fromString("10.5"), null),
+ s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"))))
check(Seq(
s"hdfs://host:9000/path1",
@@ -189,8 +280,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
// Introduce a _temporary dir to the base dir to test the robustness of the schema discovery process.
new File(base.getCanonicalPath, "_temporary").mkdir()
- println("load the partitioned table")
- read.parquet(base.getCanonicalPath).registerTempTable("t")
+ sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
@@ -237,7 +327,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- read.parquet(base.getCanonicalPath).registerTempTable("t")
+ sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
@@ -285,7 +375,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath)
+ val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath)
parquetRelation.registerTempTable("t")
withTempTable("t") {
@@ -325,7 +415,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath)
+ val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath)
parquetRelation.registerTempTable("t")
withTempTable("t") {
@@ -357,7 +447,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
(1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"),
makePartitionDir(base, defaultPartitionName, "pi" -> 2))
- read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath).registerTempTable("t")
+ sqlContext.read.format("parquet").load(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
@@ -370,7 +460,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
test("SPARK-7749 Non-partitioned table should have empty partition spec") {
withTempPath { dir =>
(1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath)
- val queryExecution = read.parquet(dir.getCanonicalPath).queryExecution
+ val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution
queryExecution.analyzed.collectFirst {
case LogicalRelation(relation: ParquetRelation2) =>
assert(relation.partitionSpec === PartitionSpec.emptySpec)
@@ -384,7 +474,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
withTempPath { dir =>
val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s")
df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath)
- checkAnswer(read.parquet(dir.getCanonicalPath), df.collect())
+ checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect())
}
}
@@ -424,12 +514,28 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
}
val schema = StructType(partitionColumns :+ StructField(s"i", StringType))
- val df = createDataFrame(sparkContext.parallelize(row :: Nil), schema)
+ val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema)
withTempPath { dir =>
df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString)
val fields = schema.map(f => Column(f.name).cast(f.dataType))
- checkAnswer(read.load(dir.toString).select(fields: _*), row)
+ checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row)
+ }
+ }
+
+ test("SPARK-8037: Ignores files whose name starts with dot") {
+ withTempPath { dir =>
+ val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d")
+
+ df.write
+ .format("parquet")
+ .partitionBy("b", "c", "d")
+ .save(dir.getCanonicalPath)
+
+ Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store"))
+ Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar"))
+
+ checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df)
}
}
}
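
Note on what the partition-discovery assertions above pin down: fractional partition values are now inferred as DoubleType rather than FloatType, partition rows are materialized as Catalyst InternalRow values carrying UTF8String, and a new boolean argument lets callers disable type inference so that every partition column stays a string. A standalone sketch of those inference rules (an illustration only, not the private parsePartition implementation):

    import scala.util.Try

    // Illustration of the inference order the tests rely on: integral first,
    // then fractional as Double (no longer Float), falling back to String.
    // With inference disabled, every raw value stays a String.
    sealed trait InferredValue
    case class IntValue(v: Int) extends InferredValue
    case class LongValue(v: Long) extends InferredValue
    case class DoubleValue(v: Double) extends InferredValue
    case class StringValue(v: String) extends InferredValue

    def inferPartitionValue(raw: String, typeInference: Boolean): InferredValue =
      if (!typeInference) StringValue(raw)
      else Try(IntValue(raw.toInt))
        .orElse(Try(LongValue(raw.toLong)))
        .orElse(Try(DoubleValue(raw.toDouble)))
        .getOrElse(StringValue(raw))

    // inferPartitionValue("10", typeInference = true)    == IntValue(10)
    // inferPartitionValue("10.5", typeInference = true)  == DoubleValue(10.5)
    // inferPartitionValue("10.5", typeInference = false) == StringValue("10.5")
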
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index b98ba09ccfc2..be3b34d5b9b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -19,16 +19,15 @@ package org.apache.spark.sql.parquet
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.{SQLConf, QueryTest}
-import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{QueryTest, Row, SQLConf}
/**
* A test suite that tests various Parquet queries.
*/
class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
- val sqlContext = TestSQLContext
+ lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+ import sqlContext.sql
test("simple select queries") {
withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
@@ -39,22 +38,22 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
- createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
- checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
+ checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
}
- catalog.unregisterTable(Seq("tmp"))
+ sqlContext.catalog.unregisterTable(Seq("tmp"))
}
test("overwriting") {
val data = (0 until 10).map(i => (i, i.toString))
- createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
- checkAnswer(table("t"), data.map(Row.fromTuple))
+ checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
}
- catalog.unregisterTable(Seq("tmp"))
+ sqlContext.catalog.unregisterTable(Seq("tmp"))
}
test("self-join") {
@@ -111,10 +110,22 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
List(Row("same", "run_5", 100)))
}
}
+
+ test("SPARK-6917 DecimalType should work with non-native types") {
+ val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i)))
+ val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
+ StructField("time", TimestampType, false)).toArray)
+ withTempPath { file =>
+ val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema)
+ df.write.parquet(file.getCanonicalPath)
+ val df2 = sqlContext.read.parquet(file.getCanonicalPath)
+ checkAnswer(df2, df.collect().toSeq)
+ }
+ }
}
class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
@@ -126,7 +137,7 @@ class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAnd
}
class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
index c964b6d98455..171a656f0e01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
@@ -20,15 +20,14 @@ package org.apache.spark.sql.parquet
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
-import org.scalatest.FunSuite
-import parquet.schema.MessageTypeParser
+import org.apache.parquet.schema.MessageTypeParser
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
-class ParquetSchemaSuite extends FunSuite with ParquetTest {
- val sqlContext = TestSQLContext
+class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
+ lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
/**
* Checks whether the reflected Parquet message type for product type `T` conforms `messageType`.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala
index 516ba373f41d..eb15a1609f1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala
@@ -33,8 +33,6 @@ import org.apache.spark.sql.{DataFrame, SaveMode}
* Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
*/
private[sql] trait ParquetTest extends SQLTestUtils {
- import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder}
- import sqlContext.sparkContext
/**
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
@@ -44,7 +42,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(data: Seq[T])
(f: String => Unit): Unit = {
withTempPath { file =>
- sparkContext.parallelize(data).toDF().write.parquet(file.getCanonicalPath)
+ sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
@@ -75,7 +73,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
data: Seq[T], path: File): Unit = {
- data.toDF().write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
+ sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index d2d1011b8e91..a71088430bfd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -26,18 +26,20 @@ import org.apache.spark.util.Utils
class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
- import caseInsensitiveContext._
+ import caseInsensitiveContext.sql
+
+ private lazy val sparkContext = caseInsensitiveContext.sparkContext
var path: File = null
override def beforeAll(): Unit = {
path = Utils.createTempDir()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
- read.json(rdd).registerTempTable("jt")
+ caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
}
override def afterAll(): Unit = {
- dropTempTable("jt")
+ caseInsensitiveContext.dropTempTable("jt")
}
after {
@@ -59,7 +61,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT a, b FROM jsonTable"),
sql("SELECT a, b FROM jt").collect())
- dropTempTable("jsonTable")
+ caseInsensitiveContext.dropTempTable("jsonTable")
}
test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") {
@@ -129,7 +131,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT * FROM jsonTable"),
sql("SELECT a * 4 FROM jt").collect())
- dropTempTable("jsonTable")
+ caseInsensitiveContext.dropTempTable("jsonTable")
// Explicitly delete the data.
if (path.exists()) Utils.deleteRecursively(path)
@@ -147,7 +149,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT * FROM jsonTable"),
sql("SELECT b FROM jt").collect())
- dropTempTable("jsonTable")
+ caseInsensitiveContext.dropTempTable("jsonTable")
}
test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index 5c3467158a01..5fc53f701299 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.sources
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class DDLScanSource extends RelationProvider {
override def createRelation(
@@ -56,26 +58,28 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
)
))
+ override def needConversion: Boolean = false
override def buildScan(): RDD[Row] = {
- sqlContext.sparkContext.parallelize(from to to).map(e => Row(s"people$e", e * 2))
+ sqlContext.sparkContext.parallelize(from to to).map { e =>
+ InternalRow(UTF8String.fromString(s"people$e"), e * 2)
+ }
}
}
class DDLTestSuite extends DataSourceTest {
- import caseInsensitiveContext._
before {
- sql(
- """
- |CREATE TEMPORARY TABLE ddlPeople
- |USING org.apache.spark.sql.sources.DDLScanSource
- |OPTIONS (
- | From '1',
- | To '10',
- | Table 'test1'
- |)
- """.stripMargin)
+ caseInsensitiveContext.sql(
+ """
+ |CREATE TEMPORARY TABLE ddlPeople
+ |USING org.apache.spark.sql.sources.DDLScanSource
+ |OPTIONS (
+ | From '1',
+ | To '10',
+ | Table 'test1'
+ |)
+ """.stripMargin)
}
sqlTest(
@@ -100,7 +104,8 @@ class DDLTestSuite extends DataSourceTest {
))
test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
- val attributes = sql("describe ddlPeople").queryExecution.executedPlan.output
+ val attributes = caseInsensitiveContext.sql("describe ddlPeople")
+ .queryExecution.executedPlan.output
assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
assert(attributes.map(_.dataType).toSet === Set(StringType))
}
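
For context on the needConversion and InternalRow changes in this suite: a relation that overrides needConversion to false declares that its scan already produces rows in the Catalyst internal representation (UTF8String instead of String, and so on), so Spark SQL skips the external-to-internal conversion step. A minimal sketch mirroring the SimpleDDLScan pattern above, with made-up names:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.sources.{BaseRelation, TableScan}
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.types.UTF8String

    // Sketch only: with needConversion = false the scan emits internal values itself.
    case class TinyScan(n: Int)(@transient val sqlContext: SQLContext)
      extends BaseRelation with TableScan {

      override def schema: StructType = StructType(Seq(
        StructField("name", StringType),
        StructField("id", IntegerType)))

      override def needConversion: Boolean = false

      override def buildScan(): RDD[Row] =
        sqlContext.sparkContext.parallelize(1 to n).map { i =>
          // internal representations: UTF8String rather than java.lang.String
          InternalRow(UTF8String.fromString(s"name_$i"), i)
        }
    }
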
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 24ed665c67d2..3f77960d0924 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -17,14 +17,18 @@
package org.apache.spark.sql.sources
+import org.scalatest.BeforeAndAfter
+
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.test.TestSQLContext
-import org.scalatest.BeforeAndAfter
+
abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
// We want to test some edge cases.
- implicit val caseInsensitiveContext = new SQLContext(TestSQLContext.sparkContext)
+ protected implicit lazy val caseInsensitiveContext = {
+ val ctx = new SQLContext(TestSQLContext.sparkContext)
+ ctx.setConf(SQLConf.CASE_SENSITIVE, "false")
+ ctx
+ }
- caseInsensitiveContext.setConf(SQLConf.CASE_SENSITIVE, "false")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index db94b1f3e892..81b3a0f0c5b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -97,7 +97,7 @@ object FiltersPushed {
class FilteredScanSuite extends DataSourceTest {
- import caseInsensitiveContext._
+ import caseInsensitiveContext.sql
before {
sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 6f375ef36237..0b7c46c482c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -26,14 +26,16 @@ import org.apache.spark.util.Utils
class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
- import caseInsensitiveContext._
+ import caseInsensitiveContext.sql
+
+ private lazy val sparkContext = caseInsensitiveContext.sparkContext
var path: File = null
override def beforeAll: Unit = {
path = Utils.createTempDir()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
- read.json(rdd).registerTempTable("jt")
+ caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
sql(
s"""
|CREATE TEMPORARY TABLE jsonTable (a int, b string)
@@ -45,8 +47,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
}
override def afterAll: Unit = {
- dropTempTable("jsonTable")
- dropTempTable("jt")
+ caseInsensitiveContext.dropTempTable("jsonTable")
+ caseInsensitiveContext.dropTempTable("jt")
Utils.deleteRecursively(path)
}
@@ -109,7 +111,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
// Writing the table to less part files.
val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5)
- read.json(rdd1).registerTempTable("jt1")
+ caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1")
sql(
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1
@@ -121,7 +123,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
// Writing the table to more part files.
val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10)
- read.json(rdd2).registerTempTable("jt2")
+ caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2")
sql(
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2
@@ -140,8 +142,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
(1 to 10).map(i => Row(i * 10, s"str$i"))
)
- dropTempTable("jt1")
- dropTempTable("jt2")
+ caseInsensitiveContext.dropTempTable("jt1")
+ caseInsensitiveContext.dropTempTable("jt2")
}
test("INSERT INTO not supported for JSONRelation for now") {
@@ -154,13 +156,14 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
}
test("save directly to the path of a JSON table") {
- table("jt").selectExpr("a * 5 as a", "b").write.mode(SaveMode.Overwrite).json(path.toString)
+ caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b")
+ .write.mode(SaveMode.Overwrite).json(path.toString)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i * 5, s"str$i"))
)
- table("jt").write.mode(SaveMode.Overwrite).json(path.toString)
+ caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i, s"str$i"))
@@ -181,7 +184,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
test("Caching") {
// Cached Query Execution
- cacheTable("jsonTable")
+ caseInsensitiveContext.cacheTable("jsonTable")
assertCached(sql("SELECT * FROM jsonTable"))
checkAnswer(
sql("SELECT * FROM jsonTable"),
@@ -220,7 +223,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT a * 2, b FROM jt").collect())
// Verify uncaching
- uncacheTable("jsonTable")
+ caseInsensitiveContext.uncacheTable("jsonTable")
assertCached(sql("SELECT * FROM jsonTable"), 0)
}
@@ -251,6 +254,6 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
"It is not allowed to insert into a table that is not an InsertableRelation."
)
- dropTempTable("oneToTen")
+ caseInsensitiveContext.dropTempTable("oneToTen")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index c2bc52e2120c..257526feab94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -52,10 +52,9 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo
}
class PrunedScanSuite extends DataSourceTest {
- import caseInsensitiveContext._
before {
- sql(
+ caseInsensitiveContext.sql(
"""
|CREATE TEMPORARY TABLE oneToTenPruned
|USING org.apache.spark.sql.sources.PrunedScanSource
@@ -115,7 +114,7 @@ class PrunedScanSuite extends DataSourceTest {
def testPruning(sqlString: String, expectedColumns: String*): Unit = {
test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
- val queryExecution = sql(sqlString).queryExecution
+ val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution
val rawPlan = queryExecution.executedPlan.collect {
case p: execution.PhysicalRDD => p
} match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
index 8331a14c9295..296b0d6f74a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.sql.sources
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class ResolvedDataSourceSuite extends FunSuite {
+class ResolvedDataSourceSuite extends SparkFunSuite {
test("builtin sources") {
assert(ResolvedDataSource.lookupDataSource("jdbc") ===
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
index 274c652dd14d..b032515a9d28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
@@ -27,7 +27,9 @@ import org.apache.spark.util.Utils
class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
- import caseInsensitiveContext._
+ import caseInsensitiveContext.sql
+
+ private lazy val sparkContext = caseInsensitiveContext.sparkContext
var originalDefaultSource: String = null
@@ -36,60 +38,63 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
var df: DataFrame = null
override def beforeAll(): Unit = {
- originalDefaultSource = conf.defaultDataSourceName
+ originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName
path = Utils.createTempDir()
path.delete()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
- df = read.json(rdd)
+ df = caseInsensitiveContext.read.json(rdd)
df.registerTempTable("jsonTable")
}
override def afterAll(): Unit = {
- conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+ caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
}
after {
- conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+ caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
Utils.deleteRecursively(path)
}
def checkLoad(): Unit = {
- conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
- checkAnswer(read.load(path.toString), df.collect())
+ caseInsensitiveContext.conf.setConf(
+ SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+ checkAnswer(caseInsensitiveContext.read.load(path.toString), df.collect())
// Test if we can pick up the data source name passed in load.
- conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
- checkAnswer(read.format("json").load(path.toString), df.collect())
- checkAnswer(read.format("json").load(path.toString), df.collect())
+ caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect())
+ checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect())
val schema = StructType(StructField("b", StringType, true) :: Nil)
checkAnswer(
- read.format("json").schema(schema).load(path.toString),
+ caseInsensitiveContext.read.format("json").schema(schema).load(path.toString),
sql("SELECT b FROM jsonTable").collect())
}
test("save with path and load") {
- conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+ caseInsensitiveContext.conf.setConf(
+ SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
df.write.save(path.toString)
checkLoad()
}
test("save with string mode and path, and load") {
- conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+ caseInsensitiveContext.conf.setConf(
+ SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
path.createNewFile()
df.write.mode("overwrite").save(path.toString)
checkLoad()
}
test("save with path and datasource, and load") {
- conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
df.write.json(path.toString)
checkLoad()
}
test("save with data source and options, and load") {
- conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
df.write.mode(SaveMode.ErrorIfExists).json(path.toString)
checkLoad()
}
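
What these save/load tests exercise is the fallback resolution: when no format is named, DataFrameReader.load and DataFrameWriter.save resolve the data source from the DEFAULT_DATA_SOURCE_NAME setting. A hedged sketch of that round trip, written as if it lived next to this suite so it can use the same conf accessors the tests use:

    import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}

    // Sketch only: write and read back through the *default* data source,
    // never calling .format(...) explicitly.
    def roundTripViaDefaultSource(ctx: SQLContext, df: DataFrame, path: String): DataFrame = {
      ctx.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
      df.write.save(path)   // resolved to the JSON source through the default
      ctx.read.load(path)   // the read side resolves the same way
    }
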
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 77af04a49174..48875773224c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -19,9 +19,13 @@ package org.apache.spark.sql.sources
import java.sql.{Timestamp, Date}
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class DefaultSource extends SimpleScanSource
@@ -47,6 +51,10 @@ class AllDataTypesScanSource extends SchemaRelationProvider {
sqlContext: SQLContext,
parameters: Map[String, String],
schema: StructType): BaseRelation = {
+ // Check that weird parameters are passed correctly.
+ parameters("option_with_underscores")
+ parameters("option.with.dots")
+
AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
}
}
@@ -60,10 +68,12 @@ case class AllDataTypesScan(
override def schema: StructType = userSpecifiedSchema
+ override def needConversion: Boolean = false
+
override def buildScan(): RDD[Row] = {
sqlContext.sparkContext.parallelize(from to to).map { i =>
- Row(
- s"str_$i",
+ InternalRow(
+ UTF8String.fromString(s"str_$i"),
s"str_$i".getBytes(),
i % 2 == 0,
i.toByte,
@@ -72,25 +82,26 @@ case class AllDataTypesScan(
i.toLong,
i.toFloat,
i.toDouble,
- new java.math.BigDecimal(i),
- new java.math.BigDecimal(i),
- new Date(1970, 1, 1),
- new Timestamp(20000 + i),
- s"varchar_$i",
+ Decimal(new java.math.BigDecimal(i)),
+ Decimal(new java.math.BigDecimal(i)),
+ DateUtils.fromJavaDate(new Date(1970, 1, 1)),
+ DateUtils.fromJavaTimestamp(new Timestamp(20000 + i)),
+ UTF8String.fromString(s"varchar_$i"),
Seq(i, i + 1),
- Seq(Map(s"str_$i" -> Row(i.toLong))),
+ Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
Map(i -> i.toString),
- Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+ Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
Row(i, i.toString),
- Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1)))))
+ Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
+ InternalRow(Seq(DateUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
}
}
}
class TableScanSuite extends DataSourceTest {
- import caseInsensitiveContext._
+ import caseInsensitiveContext.sql
- var tableWithSchemaExpected = (1 to 10).map { i =>
+ private lazy val tableWithSchemaExpected = (1 to 10).map { i =>
Row(
s"str_$i",
s"str_$i",
@@ -121,7 +132,9 @@ class TableScanSuite extends DataSourceTest {
|USING org.apache.spark.sql.sources.SimpleScanSource
|OPTIONS (
| From '1',
- | To '10'
+ | To '10',
+ | option_with_underscores 'someval',
+ | option.with.dots 'someval'
|)
""".stripMargin)
@@ -152,7 +165,9 @@ class TableScanSuite extends DataSourceTest {
|USING org.apache.spark.sql.sources.AllDataTypesScanSource
|OPTIONS (
| From '1',
- | To '10'
+ | To '10',
+ | option_with_underscores 'someval',
+ | option.with.dots 'someval'
|)
""".stripMargin)
}
@@ -215,7 +230,7 @@ class TableScanSuite extends DataSourceTest {
Nil
)
- assert(expectedSchema == table("tableWithSchema").schema)
+ assert(expectedSchema == caseInsensitiveContext.table("tableWithSchema").schema)
checkAnswer(
sql(
@@ -270,7 +285,7 @@ class TableScanSuite extends DataSourceTest {
test("Caching") {
// Cached Query Execution
- cacheTable("oneToTen")
+ caseInsensitiveContext.cacheTable("oneToTen")
assertCached(sql("SELECT * FROM oneToTen"))
checkAnswer(
sql("SELECT * FROM oneToTen"),
@@ -297,7 +312,7 @@ class TableScanSuite extends DataSourceTest {
(2 to 10).map(i => Row(i, i - 1)).toSeq)
// Verify uncaching
- uncacheTable("oneToTen")
+ caseInsensitiveContext.uncacheTable("oneToTen")
assertCached(sql("SELECT * FROM oneToTen"), 0)
}
@@ -354,7 +369,9 @@ class TableScanSuite extends DataSourceTest {
|USING org.apache.spark.sql.sources.AllDataTypesScanSource
|OPTIONS (
| from '1',
- | to '10'
+ | to '10',
+ | option_with_underscores 'someval',
+ | option.with.dots 'someval'
|)
""".stripMargin)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 17a8b0cca09d..ac4a00a6f3da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -25,11 +25,9 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
trait SQLTestUtils {
- val sqlContext: SQLContext
+ def sqlContext: SQLContext
- import sqlContext.{conf, sparkContext}
-
- protected def configuration = sparkContext.hadoopConfiguration
+ protected def configuration = sqlContext.sparkContext.hadoopConfiguration
/**
* Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
@@ -39,12 +37,12 @@ trait SQLTestUtils {
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(conf.getConf(key)).toOption)
- (keys, values).zipped.foreach(conf.setConf)
+ val currentValues = keys.map(key => Try(sqlContext.conf.getConf(key)).toOption)
+ (keys, values).zipped.foreach(sqlContext.conf.setConf)
try f finally {
keys.zip(currentValues).foreach {
- case (key, Some(value)) => conf.setConf(key, value)
- case (key, None) => conf.unsetConf(key)
+ case (key, Some(value)) => sqlContext.conf.setConf(key, value)
+ case (key, None) => sqlContext.conf.unsetConf(key)
}
}
}
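
Usage-wise, the refactored helper above is the only thing a suite needs besides a sqlContext: it records the current values, applies the overrides for the duration of the block, and restores (or unsets) them afterwards. A minimal sketch of a suite using it; "spark.sql.codegen" is just an example key:

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}

    class WithSQLConfExampleSuite extends SparkFunSuite with SQLTestUtils {
      lazy val sqlContext = TestSQLContext

      test("conf override is scoped to the block") {
        withSQLConf("spark.sql.codegen" -> "true") {
          assert(sqlContext.conf.getConf("spark.sql.codegen") === "true")
        }
        // outside the block the previous value (or unset state) is restored
      }
    }
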
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 437f697d25bf..73e6ccdb1eaf 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -22,7 +22,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>
@@ -41,6 +41,13 @@
       <artifactId>spark-hive_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
index 3458b04bfba0..c9da25253e13 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -17,23 +17,24 @@
package org.apache.spark.sql.hive.thriftserver
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService}
import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor}
-import org.apache.spark.sql.SQLConf
-import org.apache.spark.{SparkContext, SparkConf, Logging}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart}
+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
-import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerApplicationEnd, SparkListener}
import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkContext}
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
/**
* The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a
@@ -51,6 +52,7 @@ object HiveThriftServer2 extends Logging {
@DeveloperApi
def startWithContext(sqlContext: HiveContext): Unit = {
val server = new HiveThriftServer2(sqlContext)
+ sqlContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion)
server.init(sqlContext.hiveconf)
server.start()
listener = new HiveThriftServer2Listener(server, sqlContext.conf)
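
For reference, startWithContext is the embedding entry point touched here; with the added line the server also records the Hive execution version under spark.sql.hive.version before it is initialized. A hedged sketch of embedding the Thrift server in an existing application:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2

    // Sketch only: start the JDBC/ODBC server on top of an existing HiveContext.
    object EmbeddedThriftServerExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("embedded-thriftserver"))
        val hiveContext = new HiveContext(sc)
        HiveThriftServer2.startWithContext(hiveContext)
      }
    }
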
diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
similarity index 52%
rename from sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index b9d4f1c58c98..e071103df925 100644
--- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -17,78 +17,55 @@
package org.apache.spark.sql.hive.thriftserver
+import java.security.PrivilegedExceptionAction
import java.sql.{Date, Timestamp}
-import java.util.concurrent.Executors
-import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, UUID}
-
-import org.apache.commons.logging.Log
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hive.service.cli.thrift.TProtocolVersion
-import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager
+import java.util.concurrent.RejectedExecutionException
+import java.util.{Map => JMap, UUID}
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, Map => SMap}
+import scala.util.control.NonFatal
+import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.FieldSchema
-import org.apache.hadoop.security.UserGroupInformation
import org.apache.hive.service.cli._
+import org.apache.hadoop.hive.ql.metadata.Hive
+import org.apache.hadoop.hive.ql.metadata.HiveException
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.shims.ShimLoader
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.hive.service.cli.operation.ExecuteStatementOperation
-import org.apache.hive.service.cli.session.{SessionManager, HiveSession}
+import org.apache.hive.service.cli.session.HiveSession
-import org.apache.spark.{SparkContext, Logging}
-import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf}
+import org.apache.spark.Logging
import org.apache.spark.sql.execution.SetCommand
-import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf}
-/**
- * A compatibility layer for interacting with Hive version 0.13.1.
- */
-private[thriftserver] object HiveThriftServerShim {
- val version = "0.13.1"
-
- def setServerUserName(
- sparkServiceUGI: UserGroupInformation,
- sparkCliService:SparkSQLCLIService) = {
- setSuperField(sparkCliService, "serviceUGI", sparkServiceUGI)
- }
-}
-
-private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext)
- extends AbstractSparkSQLDriver(_context) {
- override def getResults(res: JList[_]): Boolean = {
- if (hiveResponse == null) {
- false
- } else {
- res.asInstanceOf[JArrayList[String]].addAll(hiveResponse)
- hiveResponse = null
- true
- }
- }
-}
private[hive] class SparkExecuteStatementOperation(
parentSession: HiveSession,
statement: String,
confOverlay: JMap[String, String],
- runInBackground: Boolean = true)(
- hiveContext: HiveContext,
- sessionToActivePool: SMap[SessionHandle, String])
- // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution
- extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging {
+ runInBackground: Boolean = true)
+ (hiveContext: HiveContext, sessionToActivePool: SMap[SessionHandle, String])
+ extends ExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)
+ with Logging {
private var result: DataFrame = _
private var iter: Iterator[SparkRow] = _
private var dataTypes: Array[DataType] = _
+ private var statementId: String = _
def close(): Unit = {
// RDDs will be cleaned automatically upon garbage collection.
- logDebug("CLOSING")
+ hiveContext.sparkContext.clearJobGroup()
+ logDebug(s"CLOSING $statementId")
+ cleanup(OperationState.CLOSED)
}
- def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) {
+ def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) {
dataTypes(ordinal) match {
case StringType =>
to += from.getString(ordinal)
@@ -149,10 +126,10 @@ private[hive] class SparkExecuteStatementOperation(
}
def getResultSetSchema: TableSchema = {
- logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}")
- if (result.queryExecution.analyzed.output.size == 0) {
+ if (result == null || result.queryExecution.analyzed.output.size == 0) {
new TableSchema(new FieldSchema("Result", "string", "") :: Nil)
} else {
+ logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}")
val schema = result.queryExecution.analyzed.output.map { attr =>
new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "")
}
@@ -160,9 +137,73 @@ private[hive] class SparkExecuteStatementOperation(
}
}
- def run(): Unit = {
- val statementId = UUID.randomUUID().toString
- logInfo(s"Running query '$statement'")
+ override def run(): Unit = {
+ setState(OperationState.PENDING)
+ setHasResultSet(true) // avoid no resultset for async run
+
+ if (!runInBackground) {
+ runInternal()
+ } else {
+ val parentSessionState = SessionState.get()
+ val hiveConf = getConfigForOperation()
+ val sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf)
+ val sessionHive = getCurrentHive()
+ val currentSqlSession = hiveContext.currentSession
+
+ // Runnable impl to call runInternal asynchronously,
+ // from a different thread
+ val backgroundOperation = new Runnable() {
+
+ override def run(): Unit = {
+ val doAsAction = new PrivilegedExceptionAction[Object]() {
+ override def run(): Object = {
+
+ // User information is part of the metastore client member in Hive
+ hiveContext.setSession(currentSqlSession)
+ Hive.set(sessionHive)
+ SessionState.setCurrentSessionState(parentSessionState)
+ try {
+ runInternal()
+ } catch {
+ case e: HiveSQLException =>
+ setOperationException(e)
+ log.error("Error running hive query: ", e)
+ }
+ return null
+ }
+ }
+
+ try {
+ ShimLoader.getHadoopShims().doAs(sparkServiceUGI, doAsAction)
+ } catch {
+ case e: Exception =>
+ setOperationException(new HiveSQLException(e))
+ logError("Error running hive query as user : " +
+ sparkServiceUGI.getShortUserName(), e)
+ }
+ }
+ }
+ try {
+ // This submit blocks if no background threads are available to run this operation
+ val backgroundHandle =
+ getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation)
+ setBackgroundHandle(backgroundHandle)
+ } catch {
+ case rejected: RejectedExecutionException =>
+ setState(OperationState.ERROR)
+ throw new HiveSQLException("The background threadpool cannot accept" +
+ " new task for execution, please retry the operation", rejected)
+ case NonFatal(e) =>
+ logError(s"Error executing query in background", e)
+ setState(OperationState.ERROR)
+ throw e
+ }
+ }
+ }
+
+ private def runInternal(): Unit = {
+ statementId = UUID.randomUUID().toString
+ logInfo(s"Running query '$statement' with $statementId")
setState(OperationState.RUNNING)
HiveThriftServer2.listener.onStatementStart(
statementId,
@@ -194,63 +235,82 @@ private[hive] class SparkExecuteStatementOperation(
}
}
dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray
- setHasResultSet(true)
} catch {
+ case e: HiveSQLException =>
+ if (getStatus().getState() == OperationState.CANCELED) {
+ return
+ } else {
+ setState(OperationState.ERROR);
+ throw e
+ }
// Actually do need to catch Throwable as some failures don't inherit from Exception and
// HiveServer will silently swallow them.
case e: Throwable =>
+ val currentState = getStatus().getState()
+ logError(s"Error executing query, currentState $currentState, ", e)
setState(OperationState.ERROR)
HiveThriftServer2.listener.onStatementError(
statementId, e.getMessage, e.getStackTraceString)
- logError("Error executing query:", e)
throw new HiveSQLException(e.toString)
}
setState(OperationState.FINISHED)
HiveThriftServer2.listener.onStatementFinish(statementId)
}
-}
-
-private[hive] class SparkSQLSessionManager(hiveContext: HiveContext)
- extends SessionManager
- with ReflectedCompositeService {
-
- private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext)
-
- override def init(hiveConf: HiveConf) {
- setSuperField(this, "hiveConf", hiveConf)
-
- val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS)
- setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize))
- getAncestorField[Log](this, 3, "LOG").info(
- s"HiveServer2: Async execution pool size $backgroundPoolSize")
- setSuperField(this, "operationManager", sparkSqlOperationManager)
- addService(sparkSqlOperationManager)
-
- initCompositeService(hiveConf)
+ override def cancel(): Unit = {
+ logInfo(s"Cancel '$statement' with $statementId")
+ if (statementId != null) {
+ hiveContext.sparkContext.cancelJobGroup(statementId)
+ }
+ cleanup(OperationState.CANCELED)
}
- override def openSession(
- protocol: TProtocolVersion,
- username: String,
- passwd: String,
- sessionConf: java.util.Map[String, String],
- withImpersonation: Boolean,
- delegationToken: String): SessionHandle = {
- hiveContext.openSession()
- val sessionHandle = super.openSession(
- protocol, username, passwd, sessionConf, withImpersonation, delegationToken)
- val session = super.getSession(sessionHandle)
- HiveThriftServer2.listener.onSessionCreated(
- session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername)
- sessionHandle
+ private def cleanup(state: OperationState) {
+ setState(state)
+ if (runInBackground) {
+ val backgroundHandle = getBackgroundHandle()
+ if (backgroundHandle != null) {
+ backgroundHandle.cancel(true)
+ }
+ }
}
- override def closeSession(sessionHandle: SessionHandle) {
- HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString)
- super.closeSession(sessionHandle)
- sparkSqlOperationManager.sessionToActivePool -= sessionHandle
+ /**
+ * If there are query-specific settings to overlay, then create a copy of the config.
+ * There are two cases where the session config passed to the Hive driver needs to be cloned:
+ * 1. Async query -
+ *    if the client changes a config setting, it should not affect the execution
+ *    that is already underway
+ * 2. confOverlay -
+ *    the query-specific settings should only be applied to the query config, not to the session
+ * @return new configuration
+ * @throws HiveSQLException
+ */
+ private def getConfigForOperation(): HiveConf = {
+ var sqlOperationConf = getParentSession().getHiveConf()
+ if (!getConfOverlay().isEmpty() || runInBackground) {
+ // clone the parent session config for this query
+ sqlOperationConf = new HiveConf(sqlOperationConf)
+
+ // apply overlay query specific settings, if any
+ getConfOverlay().foreach { case (k, v) =>
+ try {
+ sqlOperationConf.verifyAndSet(k, v)
+ } catch {
+ case e: IllegalArgumentException =>
+ throw new HiveSQLException("Error applying statement specific settings", e)
+ }
+ }
+ }
+ return sqlOperationConf
+ }
- hiveContext.detachSession()
+ private def getCurrentHive(): Hive = {
+ try {
+ return Hive.get()
+ } catch {
+ case e: HiveException =>
+ throw new HiveSQLException("Failed to get current Hive object", e);
+ }
}
}
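
The heart of the new asynchronous path is the impersonation wrapper: runInternal is wrapped in a PrivilegedExceptionAction and executed as the service UGI (through the Hadoop shims), so the statement runs with the right credentials even though it executes on a background pool thread. A hedged sketch of that pattern using plain UserGroupInformation; the operation above goes through ShimLoader instead:

    import java.security.PrivilegedExceptionAction
    import org.apache.hadoop.security.UserGroupInformation

    // Sketch only: run a block of work as the given UGI.
    def runAs[T](ugi: UserGroupInformation)(body: => T): T =
      ugi.doAs(new PrivilegedExceptionAction[T] {
        override def run(): T = body
      })

    // e.g. runAs(UserGroupInformation.getCurrentUser()) { /* runInternal() */ }
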
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
index 14f6f658d9b7..039cfa40d26b 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -32,12 +32,12 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils}
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.exec.Utilities
-import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor}
+import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory}
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.thrift.transport.TSocket
import org.apache.spark.Logging
-import org.apache.spark.sql.hive.{HiveContext, HiveShim}
+import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.util.Utils
private[hive] object SparkSQLCLIDriver {
@@ -267,7 +267,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
} else {
var ret = 0
val hconf = conf.asInstanceOf[HiveConf]
- val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf)
+ val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf)
if (proc != null) {
if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] ||
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
index 499e077d7294..41f647d5f8c5 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
@@ -21,8 +21,6 @@ import java.io.IOException
import java.util.{List => JList}
import javax.security.auth.login.LoginException
-import scala.collection.JavaConversions._
-
import org.apache.commons.logging.Log
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.shims.ShimLoader
@@ -34,7 +32,8 @@ import org.apache.hive.service.{AbstractService, Service, ServiceException}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
-import org.apache.spark.util.Utils
+
+import scala.collection.JavaConversions._
private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
extends CLIService
@@ -52,7 +51,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
try {
HiveAuthFactory.loginFromKeytab(hiveConf)
sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf)
- HiveThriftServerShim.setServerUserName(sparkServiceUGI, this)
+ setSuperField(this, "serviceUGI", sparkServiceUGI)
} catch {
case e @ (_: IOException | _: LoginException) =>
throw new ServiceException("Unable to login to kerberos with given principal/keytab", e)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
similarity index 86%
rename from sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala
rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
index 48ac9062af96..77272aecf283 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.hive.thriftserver
-import scala.collection.JavaConversions._
+import java.util.{ArrayList => JArrayList, List => JList}
import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
@@ -27,8 +27,12 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse
import org.apache.spark.Logging
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
-private[hive] abstract class AbstractSparkSQLDriver(
- val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver with Logging {
+import scala.collection.JavaConversions._
+
+private[hive] class SparkSQLDriver(
+ val context: HiveContext = SparkSQLEnv.hiveContext)
+ extends Driver
+ with Logging {
private[hive] var tableSchema: Schema = _
private[hive] var hiveResponse: Seq[String] = _
@@ -71,6 +75,16 @@ private[hive] abstract class AbstractSparkSQLDriver(
0
}
+ override def getResults(res: JList[_]): Boolean = {
+ if (hiveResponse == null) {
+ false
+ } else {
+ res.asInstanceOf[JArrayList[String]].addAll(hiveResponse)
+ hiveResponse = null
+ true
+ }
+ }
+
override def getSchema: Schema = tableSchema
override def destroy() {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
index 7c0c505e2d61..79eda1f5123b 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -22,7 +22,7 @@ import java.io.PrintStream
import scala.collection.JavaConversions._
import org.apache.spark.scheduler.StatsReportListener
-import org.apache.spark.sql.hive.{HiveShim, HiveContext}
+import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.{Logging, SparkConf, SparkContext}
import org.apache.spark.util.Utils
@@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging {
hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8"))
hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8"))
- hiveContext.setConf("spark.sql.hive.version", HiveShim.version)
+ hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion)
if (log.isDebugEnabled) {
hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) =>
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
new file mode 100644
index 000000000000..2d5ee6800228
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.util.concurrent.Executors
+
+import org.apache.commons.logging.Log
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hive.service.cli.SessionHandle
+import org.apache.hive.service.cli.session.SessionManager
+import org.apache.hive.service.cli.thrift.TProtocolVersion
+
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager
+
+
+private[hive] class SparkSQLSessionManager(hiveContext: HiveContext)
+ extends SessionManager
+ with ReflectedCompositeService {
+
+ private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext)
+
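+ // Hive's SessionManager keeps its state in private fields, so populate them through the
+ // reflection helpers instead of calling super.init().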
+ override def init(hiveConf: HiveConf) {
+ setSuperField(this, "hiveConf", hiveConf)
+
+ val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS)
+ setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize))
+ getAncestorField[Log](this, 3, "LOG").info(
+ s"HiveServer2: Async execution pool size $backgroundPoolSize")
+
+ setSuperField(this, "operationManager", sparkSqlOperationManager)
+ addService(sparkSqlOperationManager)
+
+ initCompositeService(hiveConf)
+ }
+
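+ // Open the underlying Hive session, then report it to the HiveThriftServer2 listener so the
+ // session is tracked (e.g. for the Thrift server web UI tab).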
+ override def openSession(
+ protocol: TProtocolVersion,
+ username: String,
+ passwd: String,
+ sessionConf: java.util.Map[String, String],
+ withImpersonation: Boolean,
+ delegationToken: String): SessionHandle = {
+ hiveContext.openSession()
+ val sessionHandle = super.openSession(
+ protocol, username, passwd, sessionConf, withImpersonation, delegationToken)
+ val session = super.getSession(sessionHandle)
+ HiveThriftServer2.listener.onSessionCreated(
+ session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername)
+ sessionHandle
+ }
+
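+ // Notify the listener, close the session, and drop its per-session pool mapping before
+ // detaching the Hive session state.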
+ override def closeSession(sessionHandle: SessionHandle) {
+ HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString)
+ super.closeSession(sessionHandle)
+ sparkSqlOperationManager.sessionToActivePool -= sessionHandle
+
+ hiveContext.detachSession()
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index 9c0bf02391e0..c8031ed0f343 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -44,9 +44,12 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
confOverlay: JMap[String, String],
async: Boolean): ExecuteStatementOperation = synchronized {
- val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)(
- hiveContext, sessionToActivePool)
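+ // Execute asynchronously only when the client requests it and
+ // spark.sql.hive.thriftServer.async is enabled.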
+ val runInBackground = async && hiveContext.hiveThriftServerAsync
+ val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay,
+ runInBackground)(hiveContext, sessionToActivePool)
handleToOperation.put(operation.getHandle, operation)
+ logDebug(s"Created Operation for $statement with session=$parentSession, " +
+ s"runInBackground=$runInBackground")
operation
}
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index cc07db827d35..13b0c5951ddd 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -25,16 +25,16 @@ import scala.concurrent.{Await, Promise}
import scala.sys.process.{Process, ProcessLogger}
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.util.Utils
/**
* A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary
* Hive metastore and warehouse.
*/
-class CliSuite extends FunSuite with BeforeAndAfter with Logging {
+class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging {
val warehousePath = Utils.createTempDir()
val metastorePath = Utils.createTempDir()
@@ -133,7 +133,7 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging {
}
test("Single command with -e") {
- runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK")
+ runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK")
}
test("Single command with --database") {
@@ -165,7 +165,7 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging {
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
- runCliWithin(1.minute, Seq("--jars", s"$jarFile"))(
+ runCliWithin(3.minute, Seq("--jars", s"$jarFile"))(
"""CREATE TABLE t1(key string, val string)
|ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe';
""".stripMargin
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 1fadea97fd07..178bd1f5cb16 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -19,14 +19,18 @@ package org.apache.spark.sql.hive.thriftserver
import java.io.File
import java.net.URL
-import java.sql.{Date, DriverManager, Statement}
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, DriverManager, SQLException, Statement}
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
-import scala.concurrent.{Await, Promise}
+import scala.concurrent.{Await, Promise, future}
+import scala.concurrent.ExecutionContext.Implicits.global
import scala.sys.process.{Process, ProcessLogger}
import scala.util.{Random, Try}
+import com.google.common.base.Charsets.UTF_8
+import com.google.common.io.Files
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.jdbc.HiveDriver
import org.apache.hive.service.auth.PlainSaslHelper
@@ -35,10 +39,10 @@ import org.apache.hive.service.cli.thrift.TCLIService.Client
import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.transport.TSocket
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.Logging
-import org.apache.spark.sql.hive.HiveShim
+import org.apache.spark.{Logging, SparkFunSuite}
+import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.util.Utils
object TestData {
@@ -54,7 +58,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
override def mode: ServerMode.Value = ServerMode.binary
private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = {
- // Transport creation logics below mimics HiveConnection.createBinaryTransport
+ // Transport creation logic below mimics HiveConnection.createBinaryTransport
val rawTransport = new TSocket("localhost", serverPort)
val user = System.getProperty("user.name")
val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
@@ -109,7 +113,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
withJdbcStatement { statement =>
val resultSet = statement.executeQuery("SET spark.sql.hive.version")
resultSet.next()
- assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+ assert(resultSet.getString(1) ===
+ s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}")
}
}
@@ -335,6 +340,42 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
}
)
}
+
+ test("test jdbc cancel") {
+ withJdbcStatement { statement =>
+ val queries = Seq(
+ "DROP TABLE IF EXISTS test_map",
+ "CREATE TABLE test_map(key INT, value STRING)",
+ s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map")
+
+ queries.foreach(statement.execute)
+
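+ // A deliberately expensive query (test_map cross-joined with itself ten more times) so that
+ // cancel() arrives while it is still running.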
+ val largeJoin = "SELECT COUNT(*) FROM test_map " +
+ List.fill(10)("join test_map").mkString(" ")
+ val f = future { Thread.sleep(100); statement.cancel(); }
+ val e = intercept[SQLException] {
+ statement.executeQuery(largeJoin)
+ }
+ assert(e.getMessage contains "cancelled")
+ Await.result(f, 3.minute)
+
+ // With async execution disabled, cancel() is a no-op and the query still runs to completion
+ statement.executeQuery("SET spark.sql.hive.thriftServer.async=false")
+ val sf = future { Thread.sleep(100); statement.cancel(); }
+ val smallJoin = "SELECT COUNT(*) FROM test_map " +
+ List.fill(4)("join test_map").mkString(" ")
+ val rs1 = statement.executeQuery(smallJoin)
+ Await.result(sf, 3.minute)
+ rs1.next()
+ assert(rs1.getInt(1) === math.pow(5, 5))
+ rs1.close()
+
+ val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map")
+ rs2.next()
+ assert(rs2.getInt(1) === 5)
+ rs2.close()
+ }
+ }
}
class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
@@ -363,7 +404,8 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
withJdbcStatement { statement =>
val resultSet = statement.executeQuery("SET spark.sql.hive.version")
resultSet.next()
- assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+ assert(resultSet.getString(1) ===
+ s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}")
}
}
}
@@ -391,10 +433,10 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
val statements = connections.map(_.createStatement())
try {
- statements.zip(fs).map { case (s, f) => f(s) }
+ statements.zip(fs).foreach { case (s, f) => f(s) }
} finally {
- statements.map(_.close())
- connections.map(_.close())
+ statements.foreach(_.close())
+ connections.foreach(_.close())
}
}
@@ -403,7 +445,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
}
}
-abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging {
+abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAll with Logging {
def mode: ServerMode.Value
private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")
@@ -433,15 +475,33 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit
ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT
}
+ val driverClassPath = {
+ // Write a temporary log4j.properties and prepend it to the driver classpath so that it
+ // overrides any other log4j configuration contained in dependency jar files.
+ val tempLog4jConf = Utils.createTempDir().getCanonicalPath
+
+ Files.write(
+ """log4j.rootCategory=INFO, console
+ |log4j.appender.console=org.apache.log4j.ConsoleAppender
+ |log4j.appender.console.target=System.err
+ |log4j.appender.console.layout=org.apache.log4j.PatternLayout
+ |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+ """.stripMargin,
+ new File(s"$tempLog4jConf/log4j.properties"),
+ UTF_8)
+
+ tempLog4jConf + File.pathSeparator + sys.props("java.class.path")
+ }
+
s"""$startScript
| --master local
- | --hiveconf hive.root.logger=INFO,console
| --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
| --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
| --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
| --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode
| --hiveconf $portConf=$port
- | --driver-class-path ${sys.props("java.class.path")}
+ | --driver-class-path $driverClassPath
+ | --driver-java-options -Dlog4j.debug
| --conf spark.ui.enabled=false
""".stripMargin.split("\\s+").toSeq
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
index 4c9fab7ef613..806240e6de45 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
@@ -22,12 +22,13 @@ import scala.util.Random
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.openqa.selenium.WebDriver
import org.openqa.selenium.htmlunit.HtmlUnitDriver
+import org.scalatest.{BeforeAndAfterAll, Matchers}
import org.scalatest.concurrent.Eventually._
import org.scalatest.selenium.WebBrowser
import org.scalatest.time.SpanSugar._
-import org.scalatest.{BeforeAndAfterAll, Matchers}
import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.ui.SparkUICssErrorHandler
class UISeleniumSuite
extends HiveThriftJdbcTest
@@ -40,7 +41,9 @@ class UISeleniumSuite
override def mode: ServerMode.Value = ServerMode.binary
override def beforeAll(): Unit = {
- webDriver = new HtmlUnitDriver
+ webDriver = new HtmlUnitDriver {
+ getWebClient.setCssErrorHandler(new SparkUICssErrorHandler)
+ }
super.beforeAll()
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 0b1917a39290..82c0b494598a 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -23,7 +23,6 @@ import java.util.{Locale, TimeZone}
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.SQLConf
-import org.apache.spark.sql.hive.HiveShim
import org.apache.spark.sql.hive.test.TestHive
/**
@@ -253,8 +252,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"load_dyn_part14.*", // These work alone but fail when run with other tests...
// the answer is sensitive for jdk version
- "udf_java_method"
- ) ++ HiveShim.compatibilityBlackList
+ "udf_java_method",
+
+ // Spark SQL uses Long for TimestampType and loses precision below 100ns
+ "timestamp_1",
+ "timestamp_2"
+ )
/**
* The set of tests that are believed to be working in catalyst. Tests not on whiteList or
@@ -796,8 +799,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"stats_publisher_error_1",
"subq2",
"tablename_with_select",
- "timestamp_1",
- "timestamp_2",
"timestamp_3",
"timestamp_comparison",
"timestamp_lazy",
@@ -818,19 +819,19 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf2",
"udf5",
"udf6",
- "udf7",
+ // "udf7", turn this on after we figure out null vs nan vs infinity
"udf8",
"udf9",
"udf_10_trims",
"udf_E",
"udf_PI",
"udf_abs",
- "udf_acos",
+ // "udf_acos", turn this on after we figure out null vs nan vs infinity
"udf_add",
"udf_array",
"udf_array_contains",
"udf_ascii",
- "udf_asin",
+ // "udf_asin", turn this on after we figure out null vs nan vs infinity
"udf_atan",
"udf_avg",
"udf_bigint",
@@ -918,7 +919,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_repeat",
"udf_rlike",
"udf_round",
- "udf_round_3",
+ // "udf_round_3", TODO: FIX THIS failed due to cast exception
"udf_rpad",
"udf_rtrim",
"udf_second",
@@ -932,7 +933,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_stddev_pop",
"udf_stddev_samp",
"udf_string",
- "udf_struct",
+ // "udf_struct", TODO: FIX THIS and enable it.
"udf_substring",
"udf_subtract",
"udf_sum",
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 615b07e74d53..a17546d70624 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
- <version>1.4.0-SNAPSHOT</version>
+ <version>1.5.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index fbf2c7d8cbc0..9929f318c1e3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -17,38 +17,34 @@
package org.apache.spark.sql.hive
-import java.io.{BufferedReader, File, InputStreamReader, PrintStream}
+import java.io.File
import java.net.{URL, URLClassLoader}
import java.sql.Timestamp
-import java.util.{ArrayList => JArrayList}
-import org.apache.hadoop.hive.ql.parse.VariableSubstitution
+import org.apache.hadoop.hive.common.StatsSetupConst
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.spark.sql.catalyst.ParserDialect
import scala.collection.JavaConversions._
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.HashMap
import scala.language.implicitConversions
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.metadata.Table
import org.apache.hadoop.hive.ql.parse.VariableSubstitution
-import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable}
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand}
+import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand}
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand}
-import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy}
-import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.sources.DataSourceStrategy
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -147,6 +143,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
getConf("spark.sql.hive.metastore.barrierPrefixes", "")
.split(",").filterNot(_ == "")
+ /*
+ * Whether the Hive Thrift server executes SQL statements asynchronously on a background thread pool.
+ */
+ protected[hive] def hiveThriftServerAsync: Boolean =
+ getConf("spark.sql.hive.thriftServer.async", "true").toBoolean
+
@transient
protected[sql] lazy val substitutor = new VariableSubstitution()
@@ -269,13 +271,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
* @since 1.3.0
*/
def refreshTable(tableName: String): Unit = {
- // TODO: Database support...
- catalog.refreshTable("default", tableName)
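+ // Accept an optionally database-qualified table name ("db.table"), defaulting to the
+ // current database.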
+ val dbAndTableName = tableName.split("\\.")
+ catalog.refreshTable(dbAndTableName.lift(dbAndTableName.size - 2)
+ .getOrElse(catalog.client.currentDatabase), dbAndTableName.last)
}
protected[hive] def invalidateTable(tableName: String): Unit = {
- // TODO: Database support...
- catalog.invalidateTable("default", tableName)
+ val dbAndTableName = tableName.split("\\.")
+ catalog.invalidateTable(dbAndTableName.lift(dbAndTableName.size - 2)
+ .getOrElse(catalog.client.currentDatabase), dbAndTableName.last)
}
/**
@@ -331,7 +335,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
val tableParameters = relation.hiveQlTable.getParameters
val oldTotalSize =
- Option(tableParameters.get(HiveShim.getStatsSetupConstTotalSize))
+ Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE))
.map(_.toLong)
.getOrElse(0L)
val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable)
@@ -342,7 +346,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.client.alterTable(
relation.table.copy(
properties = relation.table.properties +
- (HiveShim.getStatsSetupConstTotalSize -> newTotalSize.toString)))
+ (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString)))
}
case otherRelation =>
throw new UnsupportedOperationException(
@@ -371,10 +375,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
- override protected[sql] lazy val functionRegistry =
- new HiveFunctionRegistry with OverrideFunctionRegistry {
- override def conf: CatalystConf = currentSession().conf
- }
+ override protected[sql] lazy val functionRegistry: FunctionRegistry =
+ new OverrideFunctionRegistry(new HiveFunctionRegistry(FunctionRegistry.builtin))
/* An analyzer that uses the Hive metastore. */
@transient
@@ -564,7 +566,7 @@ private[hive] object HiveContext {
case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8")
case (decimal: java.math.BigDecimal, DecimalType()) =>
// Hive strips trailing zeros so use its toString
- HiveShim.createDecimal(decimal).toString
+ HiveDecimal.create(decimal).toString
case (other, tpe) if primitiveTypes contains tpe => other.toString
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 24cd33508263..d4f1ae8ee01d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _}
+import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory}
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
@@ -27,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -241,15 +243,16 @@ private[hive] trait HiveInspectors {
def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null
case poi: WritableConstantStringObjectInspector =>
- UTF8String(poi.getWritableConstantValue.toString)
+ UTF8String.fromString(poi.getWritableConstantValue.toString)
case poi: WritableConstantHiveVarcharObjectInspector =>
- UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue)
+ UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue)
case poi: WritableConstantHiveDecimalObjectInspector =>
HiveShim.toCatalystDecimal(
PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector,
poi.getWritableConstantValue.getHiveDecimal)
case poi: WritableConstantTimestampObjectInspector =>
- poi.getWritableConstantValue.getTimestamp.clone()
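+ // Convert to Catalyst's internal Long timestamp representation (100ns precision)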
+ val t = poi.getWritableConstantValue
+ t.getSeconds * 10000000L + t.getNanos / 100L
case poi: WritableConstantIntObjectInspector =>
poi.getWritableConstantValue.get()
case poi: WritableConstantDoubleObjectInspector =>
@@ -286,13 +289,13 @@ private[hive] trait HiveInspectors {
case pi: PrimitiveObjectInspector => pi match {
// We think HiveVarchar is also a String
case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
- UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue)
+ UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue)
case hvoi: HiveVarcharObjectInspector =>
- UTF8String(hvoi.getPrimitiveJavaObject(data).getValue)
+ UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
case x: StringObjectInspector if x.preferWritable() =>
- UTF8String(x.getPrimitiveWritableObject(data).toString)
+ UTF8String.fromString(x.getPrimitiveWritableObject(data).toString)
case x: StringObjectInspector =>
- UTF8String(x.getPrimitiveJavaObject(data))
+ UTF8String.fromString(x.getPrimitiveJavaObject(data))
case x: IntObjectInspector if x.preferWritable() => x.get(data)
case x: BooleanObjectInspector if x.preferWritable() => x.get(data)
case x: FloatObjectInspector if x.preferWritable() => x.get(data)
@@ -312,11 +315,11 @@ private[hive] trait HiveInspectors {
case x: DateObjectInspector if x.preferWritable() =>
DateUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get())
case x: DateObjectInspector => DateUtils.fromJavaDate(x.getPrimitiveJavaObject(data))
- // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object
- // if next timestamp is null, so Timestamp object is cloned
case x: TimestampObjectInspector if x.preferWritable() =>
- x.getPrimitiveWritableObject(data).getTimestamp.clone()
- case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone()
+ val t = x.getPrimitiveWritableObject(data)
+ t.getSeconds * 10000000L + t.getNanos / 100
+ case ti: TimestampObjectInspector =>
+ DateUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data))
case _ => pi.getPrimitiveJavaObject(data)
}
case li: ListObjectInspector =>
@@ -350,17 +353,20 @@ private[hive] trait HiveInspectors {
new HiveVarchar(s, s.size)
case _: JavaHiveDecimalObjectInspector =>
- (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal)
+ (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)
case _: JavaDateObjectInspector =>
(o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int])
+ case _: JavaTimestampObjectInspector =>
+ (o: Any) => DateUtils.toJavaTimestamp(o.asInstanceOf[Long])
+
case soi: StandardStructObjectInspector =>
val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
(o: Any) => {
if (o != null) {
val struct = soi.create()
- (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach {
+ (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach {
(field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
}
struct
@@ -439,36 +445,36 @@ private[hive] trait HiveInspectors {
case _ if a == null => null
case x: PrimitiveObjectInspector => x match {
// TODO we don't support the HiveVarcharObjectInspector yet.
- case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a)
+ case _: StringObjectInspector if x.preferWritable() => getStringWritable(a)
case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString()
- case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a)
+ case _: IntObjectInspector if x.preferWritable() => getIntWritable(a)
case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer]
- case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a)
+ case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a)
case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean]
- case _: FloatObjectInspector if x.preferWritable() => HiveShim.getFloatWritable(a)
+ case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a)
case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float]
- case _: DoubleObjectInspector if x.preferWritable() => HiveShim.getDoubleWritable(a)
+ case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a)
case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double]
- case _: LongObjectInspector if x.preferWritable() => HiveShim.getLongWritable(a)
+ case _: LongObjectInspector if x.preferWritable() => getLongWritable(a)
case _: LongObjectInspector => a.asInstanceOf[java.lang.Long]
- case _: ShortObjectInspector if x.preferWritable() => HiveShim.getShortWritable(a)
+ case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a)
case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short]
- case _: ByteObjectInspector if x.preferWritable() => HiveShim.getByteWritable(a)
+ case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a)
case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte]
case _: HiveDecimalObjectInspector if x.preferWritable() =>
- HiveShim.getDecimalWritable(a.asInstanceOf[Decimal])
+ getDecimalWritable(a.asInstanceOf[Decimal])
case _: HiveDecimalObjectInspector =>
- HiveShim.createDecimal(a.asInstanceOf[Decimal].toJavaBigDecimal)
- case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a)
+ HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal)
+ case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a)
case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]]
- case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a)
+ case _: DateObjectInspector if x.preferWritable() => getDateWritable(a)
case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int])
- case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a)
- case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp]
+ case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a)
+ case _: TimestampObjectInspector => DateUtils.toJavaTimestamp(a.asInstanceOf[Long])
}
case x: SettableStructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
- val row = a.asInstanceOf[Row]
+ val row = a.asInstanceOf[InternalRow]
// 1. create the pojo (most likely) object
val result = x.create()
var i = 0
@@ -484,7 +490,7 @@ private[hive] trait HiveInspectors {
result
case x: StructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
- val row = a.asInstanceOf[Row]
+ val row = a.asInstanceOf[InternalRow]
val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
var i = 0
while (i < fieldRefs.length) {
@@ -511,7 +517,7 @@ private[hive] trait HiveInspectors {
}
def wrap(
- row: Row,
+ row: InternalRow,
inspectors: Seq[ObjectInspector],
cache: Array[AnyRef]): Array[AnyRef] = {
var i = 0
@@ -574,31 +580,31 @@ private[hive] trait HiveInspectors {
*/
def toInspector(expr: Expression): ObjectInspector = expr match {
case Literal(value, StringType) =>
- HiveShim.getStringWritableConstantObjectInspector(value)
+ getStringWritableConstantObjectInspector(value)
case Literal(value, IntegerType) =>
- HiveShim.getIntWritableConstantObjectInspector(value)
+ getIntWritableConstantObjectInspector(value)
case Literal(value, DoubleType) =>
- HiveShim.getDoubleWritableConstantObjectInspector(value)
+ getDoubleWritableConstantObjectInspector(value)
case Literal(value, BooleanType) =>
- HiveShim.getBooleanWritableConstantObjectInspector(value)
+ getBooleanWritableConstantObjectInspector(value)
case Literal(value, LongType) =>
- HiveShim.getLongWritableConstantObjectInspector(value)
+ getLongWritableConstantObjectInspector(value)
case Literal(value, FloatType) =>
- HiveShim.getFloatWritableConstantObjectInspector(value)
+ getFloatWritableConstantObjectInspector(value)
case Literal(value, ShortType) =>
- HiveShim.getShortWritableConstantObjectInspector(value)
+ getShortWritableConstantObjectInspector(value)
case Literal(value, ByteType) =>
- HiveShim.getByteWritableConstantObjectInspector(value)
+ getByteWritableConstantObjectInspector(value)
case Literal(value, BinaryType) =>
- HiveShim.getBinaryWritableConstantObjectInspector(value)
+ getBinaryWritableConstantObjectInspector(value)
case Literal(value, DateType) =>
- HiveShim.getDateWritableConstantObjectInspector(value)
+ getDateWritableConstantObjectInspector(value)
case Literal(value, TimestampType) =>
- HiveShim.getTimestampWritableConstantObjectInspector(value)
+ getTimestampWritableConstantObjectInspector(value)
case Literal(value, DecimalType()) =>
- HiveShim.getDecimalWritableConstantObjectInspector(value)
+ getDecimalWritableConstantObjectInspector(value)
case Literal(_, NullType) =>
- HiveShim.getPrimitiveNullWritableConstantObjectInspector
+ getPrimitiveNullWritableConstantObjectInspector
case Literal(value, ArrayType(dt, _)) =>
val listObjectInspector = toInspector(dt)
if (value == null) {
@@ -658,8 +664,8 @@ private[hive] trait HiveInspectors {
case _: JavaFloatObjectInspector => FloatType
case _: WritableBinaryObjectInspector => BinaryType
case _: JavaBinaryObjectInspector => BinaryType
- case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w)
- case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j)
+ case w: WritableHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(w)
+ case j: JavaHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(j)
case _: WritableDateObjectInspector => DateType
case _: JavaDateObjectInspector => DateType
case _: WritableTimestampObjectInspector => TimestampType
@@ -668,10 +674,136 @@ private[hive] trait HiveInspectors {
case _: JavaVoidObjectInspector => NullType
}
+ private def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = {
+ val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo]
+ DecimalType(info.precision(), info.scale())
+ }
+
+ private def getStringWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.stringTypeInfo, getStringWritable(value))
+
+ private def getIntWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.intTypeInfo, getIntWritable(value))
+
+ private def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value))
+
+ private def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value))
+
+ private def getLongWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.longTypeInfo, getLongWritable(value))
+
+ private def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.floatTypeInfo, getFloatWritable(value))
+
+ private def getShortWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.shortTypeInfo, getShortWritable(value))
+
+ private def getByteWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.byteTypeInfo, getByteWritable(value))
+
+ private def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value))
+
+ private def getDateWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.dateTypeInfo, getDateWritable(value))
+
+ private def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value))
+
+ private def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value))
+
+ private def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.voidTypeInfo, null)
+
+ private def getStringWritable(value: Any): hadoopIo.Text =
+ if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].getBytes)
+
+ private def getIntWritable(value: Any): hadoopIo.IntWritable =
+ if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])
+
+ private def getDoubleWritable(value: Any): hiveIo.DoubleWritable =
+ if (value == null) {
+ null
+ } else {
+ new hiveIo.DoubleWritable(value.asInstanceOf[Double])
+ }
+
+ private def getBooleanWritable(value: Any): hadoopIo.BooleanWritable =
+ if (value == null) {
+ null
+ } else {
+ new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])
+ }
+
+ private def getLongWritable(value: Any): hadoopIo.LongWritable =
+ if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])
+
+ private def getFloatWritable(value: Any): hadoopIo.FloatWritable =
+ if (value == null) {
+ null
+ } else {
+ new hadoopIo.FloatWritable(value.asInstanceOf[Float])
+ }
+
+ private def getShortWritable(value: Any): hiveIo.ShortWritable =
+ if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])
+
+ private def getByteWritable(value: Any): hiveIo.ByteWritable =
+ if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])
+
+ private def getBinaryWritable(value: Any): hadoopIo.BytesWritable =
+ if (value == null) {
+ null
+ } else {
+ new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])
+ }
+
+ private def getDateWritable(value: Any): hiveIo.DateWritable =
+ if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int])
+
+ private def getTimestampWritable(value: Any): hiveIo.TimestampWritable =
+ if (value == null) {
+ null
+ } else {
+ new hiveIo.TimestampWritable(DateUtils.toJavaTimestamp(value.asInstanceOf[Long]))
+ }
+
+ private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable =
+ if (value == null) {
+ null
+ } else {
+ // TODO: precision, scale?
+ new hiveIo.HiveDecimalWritable(
+ HiveDecimal.create(value.asInstanceOf[Decimal].toJavaBigDecimal))
+ }
+
implicit class typeInfoConversions(dt: DataType) {
import org.apache.hadoop.hive.serde2.typeinfo._
import TypeInfoFactory._
+ private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match {
+ case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale)
+ case _ => new DecimalTypeInfo(
+ HiveShim.UNLIMITED_DECIMAL_PRECISION,
+ HiveShim.UNLIMITED_DECIMAL_SCALE)
+ }
+
def toTypeInfo: TypeInfo = dt match {
case ArrayType(elemType, _) =>
getListTypeInfo(elemType.toTypeInfo)
@@ -690,7 +822,7 @@ private[hive] trait HiveInspectors {
case LongType => longTypeInfo
case ShortType => shortTypeInfo
case StringType => stringTypeInfo
- case d: DecimalType => HiveShim.decimalTypeInfo(d)
+ case d: DecimalType => decimalTypeInfo(d)
case DateType => dateTypeInfo
case TimestampType => timestampTypeInfo
case NullType => voidTypeInfo
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 47b85731587d..03d544d070d4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.hive
import com.google.common.base.Objects
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.hadoop.hive.metastore.Warehouse
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.metadata._
-import org.apache.hadoop.hive.serde2.Deserializer
+import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog}
@@ -37,7 +39,6 @@ import org.apache.spark.sql.parquet.ParquetRelation2
import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources}
-import org.apache.spark.util.Utils
/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -142,7 +143,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
provider: String,
options: Map[String, String],
isExternal: Boolean): Unit = {
- val (dbName, tblName) = processDatabaseAndTableName("default", tableName)
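+ // Parse an optionally database-qualified table name, falling back to the client's
+ // current database.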
+ val dbAndTableName = tableName.split("\\.")
+ val (dbName, tblName) = processDatabaseAndTableName(
+ dbAndTableName
+ .lift(dbAndTableName.size - 2)
+ .getOrElse(client.currentDatabase), dbAndTableName.last)
val tableProperties = new scala.collection.mutable.HashMap[String, String]
tableProperties.put("spark.sql.sources.provider", provider)
@@ -202,9 +207,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
def hiveDefaultTableFilePath(tableName: String): String = {
// Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName)
+ val dbAndTableName = tableName.split("\\.")
new Path(
- new Path(client.getDatabase(client.currentDatabase).location),
- tableName.toLowerCase).toString
+ new Path(client.getDatabase(dbAndTableName.lift(dbAndTableName.size - 2)
+ .getOrElse(client.currentDatabase)).location),
+ dbAndTableName.last.toLowerCase).toString
}
def tableExists(tableIdentifier: Seq[String]): Boolean = {
@@ -301,7 +308,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val partitionColumnDataTypes = partitionSchema.map(_.dataType)
val partitions = metastoreRelation.hiveQlPartitions.map { p =>
val location = p.getLocation
- val values = Row.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
+ val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
})
ParquetPartition(values, location)
@@ -596,7 +603,7 @@ private[hive] case class MetastoreRelation
self: Product =>
- override def equals(other: scala.Any): Boolean = other match {
+ override def equals(other: Any): Boolean = other match {
case relation: MetastoreRelation =>
databaseName == relation.databaseName &&
tableName == relation.tableName &&
@@ -670,8 +677,8 @@ private[hive] case class MetastoreRelation
@transient override lazy val statistics: Statistics = Statistics(
sizeInBytes = {
- val totalSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstTotalSize)
- val rawDataSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstRawDataSize)
+ val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)
+ val rawDataSize = hiveQlTable.getParameters.get(StatsSetupConst.RAW_DATA_SIZE)
// TODO: check if this estimate is valid for tables after partition pruning.
// NOTE: getting `totalSize` directly from params is kind of hacky, but this should be
// relatively cheap if parameters for the table are populated into the metastore. An
@@ -697,11 +704,7 @@ private[hive] case class MetastoreRelation
}
}
- val tableDesc = HiveShim.getTableDesc(
- Class.forName(
- hiveQlTable.getSerializationLib,
- true,
- Utils.getContextOrSparkClassLoader).asInstanceOf[Class[Deserializer]],
+ val tableDesc = new TableDesc(
hiveQlTable.getInputFormatClass,
// The class of table should be org.apache.hadoop.hive.ql.metadata.Table because
// getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to
@@ -743,6 +746,11 @@ private[hive] case class MetastoreRelation
private[hive] object HiveMetastoreTypes {
def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType)
+ def decimalMetastoreString(decimalType: DecimalType): String = decimalType match {
+ case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)"
+ case _ => s"decimal($HiveShim.UNLIMITED_DECIMAL_PRECISION,$HiveShim.UNLIMITED_DECIMAL_SCALE)"
+ }
+
def toMetastoreType(dt: DataType): String = dt match {
case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
case StructType(fields) =>
@@ -759,7 +767,7 @@ private[hive] object HiveMetastoreTypes {
case BinaryType => "binary"
case BooleanType => "boolean"
case DateType => "date"
- case d: DecimalType => HiveShim.decimalMetastoreString(d)
+ case d: DecimalType => decimalMetastoreString(d)
case TimestampType => "timestamp"
case NullType => "void"
case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 3915ee835685..ca4b80b51b23 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.hive
import java.sql.Date
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.ql.{ErrorMsg, Context}
@@ -39,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.execution.ExplainCommand
import org.apache.spark.sql.sources.DescribeCommand
+import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema}
import org.apache.spark.sql.types._
@@ -46,6 +45,7 @@ import org.apache.spark.util.random.RandomSampler
/* Implicit conversions */
import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
/**
* Used when we need to start parsing the AST before deciding that we are going to pass the command
@@ -57,7 +57,7 @@ private[hive] case object NativePlaceholder extends LogicalPlan {
override def output: Seq[Attribute] = Seq.empty
}
-case class CreateTableAsSelect(
+private[hive] case class CreateTableAsSelect(
tableDesc: HiveTable,
child: LogicalPlan,
allowExisting: Boolean) extends UnaryNode with Command {
@@ -1307,16 +1307,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
HiveParser.DecimalLiteral)
/* Case insensitive matches */
- val ARRAY = "(?i)ARRAY".r
- val COALESCE = "(?i)COALESCE".r
val COUNT = "(?i)COUNT".r
- val AVG = "(?i)AVG".r
val SUM = "(?i)SUM".r
- val MAX = "(?i)MAX".r
- val MIN = "(?i)MIN".r
- val UPPER = "(?i)UPPER".r
- val LOWER = "(?i)LOWER".r
- val RAND = "(?i)RAND".r
val AND = "(?i)AND".r
val OR = "(?i)OR".r
val NOT = "(?i)NOT".r
@@ -1330,8 +1322,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val BETWEEN = "(?i)BETWEEN".r
val WHEN = "(?i)WHEN".r
val CASE = "(?i)CASE".r
- val SUBSTR = "(?i)SUBSTR(?:ING)?".r
- val SQRT = "(?i)SQRT".r
protected def nodeToExpr(node: Node): Expression = node match {
/* Attribute References */
@@ -1353,18 +1343,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
UnresolvedStar(Some(name))
/* Aggregate Functions */
- case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg))
- case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg))
case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1))
case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr))
- case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg))
case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg))
- case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg))
- case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg))
-
- /* System functions about string operations */
- case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg))
- case Token("TOK_FUNCTION", Token(LOWER(), Nil) :: arg :: Nil) => Lower(nodeToExpr(arg))
/* Casts */
case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
@@ -1414,7 +1395,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right))
case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right))
case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right))
- case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg))
/* Comparisons */
case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
@@ -1469,17 +1449,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("[", child :: ordinal :: Nil) =>
UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))
- /* Other functions */
- case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>
- CreateArray(children.map(nodeToExpr))
- case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand()
- case Token("TOK_FUNCTION", Token(RAND(), Nil) :: seed :: Nil) => Rand(seed.toString.toLong)
- case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) =>
- Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType))
- case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>
- Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length))
- case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr))
-
/* Window Functions */
case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) =>
val function = UnresolvedWindowFunction(name, args.map(nodeToExpr))
@@ -1561,6 +1530,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
""".stripMargin)
}
+ /* Case insensitive matches for Window Specification */
+ val PRECEDING = "(?i)preceding".r
+ val FOLLOWING = "(?i)following".r
+ val CURRENT = "(?i)current".r
def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match {
case Token(windowName, Nil) :: Nil =>
// Refer to a window spec defined in the window clause.
@@ -1614,11 +1587,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
} else {
val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame)
def nodeToBoundary(node: Node): FrameBoundary = node match {
- case Token("preceding", Token(count, Nil) :: Nil) =>
- if (count == "unbounded") UnboundedPreceding else ValuePreceding(count.toInt)
- case Token("following", Token(count, Nil) :: Nil) =>
- if (count == "unbounded") UnboundedFollowing else ValueFollowing(count.toInt)
- case Token("current", Nil) => CurrentRow
+ case Token(PRECEDING(), Token(count, Nil) :: Nil) =>
+ if (count.toLowerCase() == "unbounded") {
+ UnboundedPreceding
+ } else {
+ ValuePreceding(count.toInt)
+ }
+ case Token(FOLLOWING(), Token(count, Nil) :: Nil) =>
+ if (count.toLowerCase() == "unbounded") {
+ UnboundedFollowing
+ } else {
+ ValueFollowing(count.toInt)
+ }
+ case Token(CURRENT(), Nil) => CurrentRow
case _ =>
throw new NotImplementedError(
s"""No parse rules for the Window Frame Boundary based on Node ${node.getName}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
new file mode 100644
index 000000000000..d08c59415165
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.{InputStream, OutputStream}
+import java.rmi.server.UID
+
+/* Implicit conversions */
+import scala.collection.JavaConversions._
+import scala.language.implicitConversions
+import scala.reflect.ClassTag
+
+import com.esotericsoftware.kryo.Kryo
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
+import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
+import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
+import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
+import org.apache.hadoop.io.Writable
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.types.Decimal
+import org.apache.spark.util.Utils
+
+private[hive] object HiveShim {
+ // Precision and scale to pass for unlimited decimals; these are the same as the precision and
+ // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
+ val UNLIMITED_DECIMAL_PRECISION = 38
+ val UNLIMITED_DECIMAL_SCALE = 18
+
+ /*
+ * This function became private in Hive 0.13, so we re-implement it here to work around the Hive bug.
+ */
+ private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) {
+ val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "")
+ val result: StringBuilder = new StringBuilder(old)
+ var first: Boolean = old.isEmpty
+
+ for (col <- cols) {
+ if (first) {
+ first = false
+ } else {
+ result.append(',')
+ }
+ result.append(col)
+ }
+ conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString)
+ }
+
+ /*
+ * ColumnProjectionUtils.appendReadColumns cannot be called directly when ids is null or empty
+ */
+ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) {
+ if (ids != null && ids.nonEmpty) {
+ ColumnProjectionUtils.appendReadColumns(conf, ids)
+ }
+ if (names != null && names.nonEmpty) {
+ appendReadColumnNames(conf, names)
+ }
+ }
+
+ /*
+ * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that
+ * needs to be initialized before serialization.
+ */
+ def prepareWritable(w: Writable): Writable = {
+ w match {
+ case w: AvroGenericRecordWritable =>
+ w.setRecordReaderID(new UID())
+ case _ =>
+ }
+ w
+ }
+
+ def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
+ if (hdoi.preferWritable()) {
+ Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue,
+ hdoi.precision(), hdoi.scale())
+ } else {
+ Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
+ }
+ }
+
+ /**
+ * This class handles UDF creation as well as UDF instance serialization and
+ * de-serialization across process boundaries.
+ *
+ * A detailed discussion can be found at https://github.com/apache/spark/pull/3640
+ *
+ * @param functionClassName UDF class name
+ */
+ private[hive] case class HiveFunctionWrapper(var functionClassName: String)
+ extends java.io.Externalizable {
+
+ // No-arg constructor required for Externalizable deserialization
+ def this() = this(null)
+
+ @transient
+ def deserializeObjectByKryo[T: ClassTag](
+ kryo: Kryo,
+ in: InputStream,
+ clazz: Class[_]): T = {
+ val inp = new Input(in)
+ val t: T = kryo.readObject(inp, clazz).asInstanceOf[T]
+ inp.close()
+ t
+ }
+
+ @transient
+ def serializeObjectByKryo(
+ kryo: Kryo,
+ plan: Object,
+ out: OutputStream) {
+ val output: Output = new Output(out)
+ kryo.writeObject(output, plan)
+ output.close()
+ }
+
+ def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
+ deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz)
+ .asInstanceOf[UDFType]
+ }
+
+ def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
+ serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out)
+ }
+
+ private var instance: AnyRef = null
+
+ def writeExternal(out: java.io.ObjectOutput) {
+ // output the function name
+ out.writeUTF(functionClassName)
+
+      // Write a flag indicating whether instance is null
+ out.writeBoolean(instance != null)
+ if (instance != null) {
+        // Some UDFs are serializable and others are not;
+        // Hive's Utilities can handle both cases
+ val baos = new java.io.ByteArrayOutputStream()
+ serializePlan(instance, baos)
+ val functionInBytes = baos.toByteArray
+
+ // output the function bytes
+ out.writeInt(functionInBytes.length)
+ out.write(functionInBytes, 0, functionInBytes.length)
+ }
+ }
+
+ def readExternal(in: java.io.ObjectInput) {
+ // read the function name
+ functionClassName = in.readUTF()
+
+ if (in.readBoolean()) {
+        // the instance was non-null when serialized,
+        // so read back the serialized function bytes
+ val functionInBytesLength = in.readInt()
+ val functionInBytes = new Array[Byte](functionInBytesLength)
+ in.read(functionInBytes, 0, functionInBytesLength)
+
+ // deserialize the function object via Hive Utilities
+ instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes),
+ Utils.getContextOrSparkClassLoader.loadClass(functionClassName))
+ }
+ }
+
+ def createFunction[UDFType <: AnyRef](): UDFType = {
+ if (instance != null) {
+ instance.asInstanceOf[UDFType]
+ } else {
+ val func = Utils.getContextOrSparkClassLoader
+ .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+ if (!func.isInstanceOf[UDF]) {
+          // Cache the function unless it is a simple UDF,
+          // since a new instance must be created for every use of a simple UDF
+ instance = func
+ }
+ func
+ }
+ }
+ }
+
+ /*
+   * Bug introduced in Hive 0.13: FileSinkDesc is serializable, but its member path is not.
+   * Work around it through this wrapper.
+ */
+ implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = {
+ val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed)
+ f.setCompressCodec(w.compressCodec)
+ f.setCompressType(w.compressType)
+ f.setTableInfo(w.tableInfo)
+ f.setDestTableId(w.destTableId)
+ f
+ }
+
+ /*
+   * Bug introduced in Hive 0.13: FileSinkDesc is serializable, but its member path is not.
+   * Work around it through this wrapper.
+ */
+ private[hive] class ShimFileSinkDesc(
+ var dir: String,
+ var tableInfo: TableDesc,
+ var compressed: Boolean)
+ extends Serializable with Logging {
+ var compressCodec: String = _
+ var compressType: String = _
+ var destTableId: Int = _
+
+ def setCompressed(compressed: Boolean) {
+ this.compressed = compressed
+ }
+
+ def getDirName(): String = dir
+
+ def setDestTableId(destTableId: Int) {
+ this.destTableId = destTableId
+ }
+
+ def setTableInfo(tableInfo: TableDesc) {
+ this.tableInfo = tableInfo
+ }
+
+ def setCompressCodec(intermediateCompressorCodec: String) {
+ compressCodec = intermediateCompressorCodec
+ }
+
+ def setCompressType(intermediateCompressType: String) {
+ compressType = intermediateCompressType
+ }
+ }
+}
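
As context for the HiveFunctionWrapper added above: it implements java.io.Externalizable by writing the UDF class name, a presence flag, and, when an instance was cached, a length-prefixed byte payload produced by Hive's runtime Kryo. The sketch below reproduces that wire layout with a toy class and plain Java serialization so it is self-contained and runnable; ToyWrapper and ToyWrapperDemo are illustrative names, not part of the patch, and the payload is an arbitrary byte array rather than a Kryo-serialized UDF.

import java.io._

// Toy stand-in for HiveFunctionWrapper: same writeExternal/readExternal layout.
class ToyWrapper(var className: String) extends Externalizable {
  def this() = this(null) // public no-arg constructor required by Externalizable
  var payload: Array[Byte] = null

  override def writeExternal(out: ObjectOutput): Unit = {
    out.writeUTF(className)            // 1. the function class name
    out.writeBoolean(payload != null)  // 2. flag: is a serialized instance present?
    if (payload != null) {
      out.writeInt(payload.length)     // 3. length-prefixed payload bytes
      out.write(payload, 0, payload.length)
    }
  }

  override def readExternal(in: ObjectInput): Unit = {
    className = in.readUTF()
    if (in.readBoolean()) {
      val bytes = new Array[Byte](in.readInt())
      in.readFully(bytes)              // readFully guards against short reads
      payload = bytes
    }
  }
}

object ToyWrapperDemo {
  def main(args: Array[String]): Unit = {
    val w = new ToyWrapper("com.example.SomeUdf")
    w.payload = Array[Byte](1, 2, 3)

    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(w)                 // invokes writeExternal
    oos.close()

    val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    val back = ois.readObject().asInstanceOf[ToyWrapper]  // invokes readExternal
    println(s"${back.className} -> ${back.payload.toSeq}")
  }
}
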
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index c6b65106452b..452b7f0bcc74 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
-import org.apache.spark.sql.catalyst.expressions.{Row, _}
+import org.apache.spark.sql.catalyst.expressions.{InternalRow, _}
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -137,7 +137,7 @@ private[hive] trait HiveStrategies {
val partitionLocations = partitions.map(_.getLocation)
if (partitionLocations.isEmpty) {
- PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil
+ PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil
} else {
hiveContext
.read.parquet(partitionLocations: _*)
@@ -165,7 +165,7 @@ private[hive] trait HiveStrategies {
// TODO: Remove this hack for Spark 1.3.
case iae: java.lang.IllegalArgumentException
if iae.getMessage.contains("Can not create a Path from an empty string") =>
- PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil
+ PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil
}
case _ => Nil
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 294fc3bd7d5e..485810320f3c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -25,14 +25,13 @@ import org.apache.hadoop.hive.ql.exec.Utilities
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc}
import org.apache.hadoop.hive.serde2.Deserializer
-import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}
-import org.apache.spark.SerializableWritable
+import org.apache.spark.{Logging, SerializableWritable}
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.Logging
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateUtils
@@ -42,9 +41,9 @@ import org.apache.spark.util.Utils
* A trait for subclasses that handle table scans.
*/
private[hive] sealed trait TableReader {
- def makeRDDForTable(hiveTable: HiveTable): RDD[Row]
+ def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow]
- def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row]
+ def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[InternalRow]
}
@@ -75,7 +74,7 @@ class HadoopTableReader(
private val _broadcastedHiveConf =
sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf))
- override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] =
+ override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] =
makeRDDForTable(
hiveTable,
Class.forName(
@@ -95,7 +94,7 @@ class HadoopTableReader(
def makeRDDForTable(
hiveTable: HiveTable,
deserializerClass: Class[_ <: Deserializer],
- filterOpt: Option[PathFilter]): RDD[Row] = {
+ filterOpt: Option[PathFilter]): RDD[InternalRow] = {
assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table,
since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""")
@@ -126,7 +125,7 @@ class HadoopTableReader(
deserializedHadoopRDD
}
- override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = {
+ override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[InternalRow] = {
val partitionToDeserializer = partitions.map(part =>
(part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap
makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None)
@@ -145,7 +144,7 @@ class HadoopTableReader(
def makeRDDForPartitionedTable(
partitionToDeserializer: Map[HivePartition,
Class[_ <: Deserializer]],
- filterOpt: Option[PathFilter]): RDD[Row] = {
+ filterOpt: Option[PathFilter]): RDD[InternalRow] = {
// SPARK-5068:get FileStatus and do the filtering locally when the path is not exists
def verifyPartitionPath(
@@ -172,7 +171,7 @@ class HadoopTableReader(
path.toString + tails
}
- val partPath = HiveShim.getDataLocationPath(partition)
+ val partPath = partition.getDataLocation
val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size();
var pathPatternStr = getPathPatternByPath(partNum, partPath)
if (!pathPatternSet.contains(pathPatternStr)) {
@@ -187,7 +186,7 @@ class HadoopTableReader(
val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer)
.map { case (partition, partDeserializer) =>
val partDesc = Utilities.getPartitionDesc(partition)
- val partPath = HiveShim.getDataLocationPath(partition)
+ val partPath = partition.getDataLocation
val inputPathStr = applyFilterIfNeeded(partPath, filterOpt)
val ifc = partDesc.getInputFileFormatClass
.asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
@@ -244,7 +243,7 @@ class HadoopTableReader(
// Even if we don't use any partitions, we still need an empty RDD
if (hivePartitionRDDs.size == 0) {
- new EmptyRDD[Row](sc.sparkContext)
+ new EmptyRDD[InternalRow](sc.sparkContext)
} else {
new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
}
@@ -320,12 +319,12 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
rawDeser: Deserializer,
nonPartitionKeyAttrs: Seq[(Attribute, Int)],
mutableRow: MutableRow,
- tableDeser: Deserializer): Iterator[Row] = {
+ tableDeser: Deserializer): Iterator[InternalRow] = {
val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) {
rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector]
} else {
- HiveShim.getConvertedOI(
+ ObjectInspectorConverters.getConvertedOI(
rawDeser.getObjectInspector,
tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector]
}
@@ -364,10 +363,10 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
row.update(ordinal, HiveShim.toCatalystDecimal(oi, value))
case oi: TimestampObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) =>
- row.update(ordinal, oi.getPrimitiveJavaObject(value).clone())
+ row.setLong(ordinal, DateUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value)))
case oi: DateObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) =>
- row.update(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value)))
+ row.setInt(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value)))
case oi: BinaryObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) =>
row.update(ordinal, oi.getPrimitiveJavaObject(value))
@@ -392,7 +391,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
i += 1
}
- mutableRow: Row
+ mutableRow: InternalRow
}
}
}
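
The TableReader changes above keep the existing fill pattern: for each column a (value, row, ordinal) => Unit updater is chosen once from the object inspector, and the per-row loop only applies the precomputed updaters (now writing dates and timestamps as primitive int/long values). Below is a minimal, self-contained sketch of that pattern under simplified assumptions; SimpleRow and the string-keyed updaterFor are stand-ins for Spark's MutableRow and Hive's object inspectors, not APIs from the patch.

// Choose an updater per column once, then run a tight row loop with no
// per-value type dispatch, mirroring fillObject above.
final class SimpleRow(n: Int) {
  private val values = new Array[Any](n)
  def setInt(ordinal: Int, v: Int): Unit = values(ordinal) = v
  def setLong(ordinal: Int, v: Long): Unit = values(ordinal) = v
  override def toString: String = values.mkString("[", ", ", "]")
}

object ConverterSketch {
  type Updater = (Any, SimpleRow, Int) => Unit

  // Resolved once per column, outside the row loop.
  def updaterFor(columnType: String): Updater = columnType match {
    case "int"  => (value, row, ordinal) => row.setInt(ordinal, value.asInstanceOf[Int])
    case "long" => (value, row, ordinal) => row.setLong(ordinal, value.asInstanceOf[Long])
  }

  def main(args: Array[String]): Unit = {
    val updaters = Array(updaterFor("int"), updaterFor("long"))
    val row = new SimpleRow(2)
    val rawValues: Array[Any] = Array(7, 123456789L)
    var i = 0
    while (i < updaters.length) {  // hot loop: no pattern matching per value
      updaters(i)(rawValues(i), row, i)
      i += 1
    }
    println(row)  // [7, 123456789]
  }
}
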
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index 99aa0f1ded3f..982ed63874a5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -27,7 +27,7 @@ import scala.language.reflectiveCalls
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.metastore.api.Database
import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.metastore.TableType
+import org.apache.hadoop.hive.metastore.{TableType => HTableType}
import org.apache.hadoop.hive.metastore.api
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.metadata
@@ -59,8 +59,7 @@ private[hive] class ClientWrapper(
version: HiveVersion,
config: Map[String, String])
extends ClientInterface
- with Logging
- with ReflectionMagic {
+ with Logging {
// Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur.
private val outputBuffer = new java.io.OutputStream {
@@ -90,6 +89,13 @@ private[hive] class ClientWrapper(
}
}
+ private val shim = version match {
+ case hive.v12 => new Shim_v0_12()
+ case hive.v13 => new Shim_v0_13()
+ case hive.v14 => new Shim_v0_14()
+ }
+
+ // Create an internal session state for this ClientWrapper.
val state = {
val original = Thread.currentThread().getContextClassLoader
Thread.currentThread().setContextClassLoader(getClass.getClassLoader)
@@ -126,16 +132,16 @@ private[hive] class ClientWrapper(
*/
private def withHiveState[A](f: => A): A = synchronized {
val original = Thread.currentThread().getContextClassLoader
+ // This setContextClassLoader is used for Hive 0.12's metastore since Hive 0.12 will not
+ // internally override the context class loader of the current thread with the class loader
+ // associated with the HiveConf in `state`.
Thread.currentThread().setContextClassLoader(getClass.getClassLoader)
+ // Set the thread local metastore client to the client associated with this ClientWrapper.
Hive.set(client)
- version match {
- case hive.v12 =>
- classOf[SessionState]
- .callStatic[SessionState, SessionState]("start", state)
- case hive.v13 =>
- classOf[SessionState]
- .callStatic[SessionState, SessionState]("setCurrentSessionState", state)
- }
+ // Starting from Hive 0.13.0, setCurrentSessionState will use the classLoader associated
+ // with the HiveConf in `state` to override the context class loader of the current
+ // thread.
+ shim.setCurrentSessionState(state)
val ret = try f finally {
Thread.currentThread().setContextClassLoader(original)
}
@@ -193,15 +199,12 @@ private[hive] class ClientWrapper(
properties = h.getParameters.toMap,
serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.toMap,
tableType = h.getTableType match {
- case TableType.MANAGED_TABLE => ManagedTable
- case TableType.EXTERNAL_TABLE => ExternalTable
- case TableType.VIRTUAL_VIEW => VirtualView
- case TableType.INDEX_TABLE => IndexTable
- },
- location = version match {
- case hive.v12 => Option(h.call[URI]("getDataLocation")).map(_.toString)
- case hive.v13 => Option(h.call[Path]("getDataLocation")).map(_.toString)
+ case HTableType.MANAGED_TABLE => ManagedTable
+ case HTableType.EXTERNAL_TABLE => ExternalTable
+ case HTableType.VIRTUAL_VIEW => VirtualView
+ case HTableType.INDEX_TABLE => IndexTable
},
+ location = shim.getDataLocation(h),
inputFormat = Option(h.getInputFormatClass).map(_.getName),
outputFormat = Option(h.getOutputFormatClass).map(_.getName),
serde = Option(h.getSerializationLib),
@@ -231,14 +234,7 @@ private[hive] class ClientWrapper(
// set create time
qlTable.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int])
- version match {
- case hive.v12 =>
- table.location.map(new URI(_)).foreach(u => qlTable.call[URI, Unit]("setDataLocation", u))
- case hive.v13 =>
- table.location
- .map(new org.apache.hadoop.fs.Path(_))
- .foreach(qlTable.call[Path, Unit]("setDataLocation", _))
- }
+ table.location.foreach { loc => shim.setDataLocation(qlTable, loc) }
table.inputFormat.map(toInputFormat).foreach(qlTable.setInputFormatClass)
table.outputFormat.map(toOutputFormat).foreach(qlTable.setOutputFormatClass)
table.serde.foreach(qlTable.setSerializationLib)
@@ -279,13 +275,7 @@ private[hive] class ClientWrapper(
override def getAllPartitions(hTable: HiveTable): Seq[HivePartition] = withHiveState {
val qlTable = toQlTable(hTable)
- val qlPartitions = version match {
- case hive.v12 =>
- client.call[metadata.Table, JSet[metadata.Partition]]("getAllPartitionsForPruner", qlTable)
- case hive.v13 =>
- client.call[metadata.Table, JSet[metadata.Partition]]("getAllPartitionsOf", qlTable)
- }
- qlPartitions.toSeq.map(toHivePartition)
+ shim.getAllPartitions(client, qlTable).map(toHivePartition)
}
override def listTables(dbName: String): Seq[String] = withHiveState {
@@ -315,15 +305,7 @@ private[hive] class ClientWrapper(
val tokens: Array[String] = cmd_trimmed.split("\\s+")
// The remainder of the command.
val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
- val proc: CommandProcessor = version match {
- case hive.v12 =>
- classOf[CommandProcessorFactory]
- .callStatic[String, HiveConf, CommandProcessor]("get", tokens(0), conf)
- case hive.v13 =>
- classOf[CommandProcessorFactory]
- .callStatic[Array[String], HiveConf, CommandProcessor]("get", Array(tokens(0)), conf)
- }
-
+ val proc = shim.getCommandProcessor(tokens(0), conf)
proc match {
case driver: Driver =>
val response: CommandProcessorResponse = driver.run(cmd)
@@ -334,21 +316,7 @@ private[hive] class ClientWrapper(
}
driver.setMaxRows(maxRows)
- val results = version match {
- case hive.v12 =>
- val res = new JArrayList[String]
- driver.call[JArrayList[String], Boolean]("getResults", res)
- res.toSeq
- case hive.v13 =>
- val res = new JArrayList[Object]
- driver.call[JList[Object], Boolean]("getResults", res)
- res.map { r =>
- r match {
- case s: String => s
- case a: Array[Object] => a(0).asInstanceOf[String]
- }
- }
- }
+ val results = shim.getDriverResults(driver)
driver.close()
results
@@ -382,8 +350,8 @@ private[hive] class ClientWrapper(
holdDDLTime: Boolean,
inheritTableSpecs: Boolean,
isSkewedStoreAsSubdir: Boolean): Unit = withHiveState {
-
- client.loadPartition(
+ shim.loadPartition(
+ client,
new Path(loadPath), // TODO: Use URI
tableName,
partSpec,
@@ -398,7 +366,8 @@ private[hive] class ClientWrapper(
tableName: String,
replace: Boolean,
holdDDLTime: Boolean): Unit = withHiveState {
- client.loadTable(
+ shim.loadTable(
+ client,
new Path(loadPath),
tableName,
replace,
@@ -413,7 +382,8 @@ private[hive] class ClientWrapper(
numDP: Int,
holdDDLTime: Boolean,
listBucketingEnabled: Boolean): Unit = withHiveState {
- client.loadDynamicPartitions(
+ shim.loadDynamicPartitions(
+ client,
new Path(loadPath),
tableName,
partSpec,
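
The withHiveState change above keeps the same save/set/restore discipline around the thread's context class loader while delegating the version-specific call to the shim. The following is a minimal sketch of just that loader discipline, with the Hive-specific session handling omitted; withClassLoader and ContextLoaderSketch are illustrative names, not APIs introduced by the patch.

// Save/set/restore of the context class loader around a block of work,
// always restoring the original even if the block throws.
object ContextLoaderSketch {
  def withClassLoader[A](loader: ClassLoader)(f: => A): A = {
    val original = Thread.currentThread().getContextClassLoader
    Thread.currentThread().setContextClassLoader(loader)
    try f finally {
      Thread.currentThread().setContextClassLoader(original)
    }
  }

  def main(args: Array[String]): Unit = {
    val seenInside = withClassLoader(getClass.getClassLoader) {
      Thread.currentThread().getContextClassLoader
    }
    println(s"loader inside the block: $seenInside")
  }
}
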
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
new file mode 100644
index 000000000000..40c167926c8d
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -0,0 +1,349 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.client
+
+import java.lang.{Boolean => JBoolean, Integer => JInteger}
+import java.lang.reflect.{Method, Modifier}
+import java.net.URI
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet}
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.Driver
+import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
+import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
+import org.apache.hadoop.hive.ql.session.SessionState
+
+/**
+ * A shim that defines the interface between ClientWrapper and the underlying Hive library used to
+ * talk to the metastore. Each Hive version has its own implementation of this class, providing
+ * version-specific variants of the functions it needs.
+ *
+ * The guidelines for writing shims are:
+ * - always extend from the previous version unless it is really not possible
+ * - resolve methods in lazy vals, both for quicker access across multiple invocations and to
+ *   avoid runtime errors from eagerly looking up methods that do not exist in the Hive version
+ *   actually on the classpath (a consequence of extending the previous shim).
+ */
+private[client] sealed abstract class Shim {
+
+ def setCurrentSessionState(state: SessionState): Unit
+
+ /**
+   * This method lives in the shim because Table.getDataLocation has a different return type
+   * across Hive versions (URI in 0.12, Path in 0.13 and later); its parameters are the same.
+ */
+ def getDataLocation(table: Table): Option[String]
+
+ def setDataLocation(table: Table, loc: String): Unit
+
+ def getAllPartitions(hive: Hive, table: Table): Seq[Partition]
+
+ def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor
+
+ def getDriverResults(driver: Driver): Seq[String]
+
+ def loadPartition(
+ hive: Hive,
+ loadPath: Path,
+ tableName: String,
+ partSpec: JMap[String, String],
+ replace: Boolean,
+ holdDDLTime: Boolean,
+ inheritTableSpecs: Boolean,
+ isSkewedStoreAsSubdir: Boolean): Unit
+
+ def loadTable(
+ hive: Hive,
+ loadPath: Path,
+ tableName: String,
+ replace: Boolean,
+ holdDDLTime: Boolean): Unit
+
+ def loadDynamicPartitions(
+ hive: Hive,
+ loadPath: Path,
+ tableName: String,
+ partSpec: JMap[String, String],
+ replace: Boolean,
+ numDP: Int,
+ holdDDLTime: Boolean,
+ listBucketingEnabled: Boolean): Unit
+
+ protected def findStaticMethod(klass: Class[_], name: String, args: Class[_]*): Method = {
+ val method = findMethod(klass, name, args: _*)
+ require(Modifier.isStatic(method.getModifiers()),
+ s"Method $name of class $klass is not static.")
+ method
+ }
+
+ protected def findMethod(klass: Class[_], name: String, args: Class[_]*): Method = {
+ klass.getMethod(name, args: _*)
+ }
+
+}
+
+private[client] class Shim_v0_12 extends Shim {
+
+ private lazy val startMethod =
+ findStaticMethod(
+ classOf[SessionState],
+ "start",
+ classOf[SessionState])
+ private lazy val getDataLocationMethod = findMethod(classOf[Table], "getDataLocation")
+ private lazy val setDataLocationMethod =
+ findMethod(
+ classOf[Table],
+ "setDataLocation",
+ classOf[URI])
+ private lazy val getAllPartitionsMethod =
+ findMethod(
+ classOf[Hive],
+ "getAllPartitionsForPruner",
+ classOf[Table])
+ private lazy val getCommandProcessorMethod =
+ findStaticMethod(
+ classOf[CommandProcessorFactory],
+ "get",
+ classOf[String],
+ classOf[HiveConf])
+ private lazy val getDriverResultsMethod =
+ findMethod(
+ classOf[Driver],
+ "getResults",
+ classOf[JArrayList[String]])
+ private lazy val loadPartitionMethod =
+ findMethod(
+ classOf[Hive],
+ "loadPartition",
+ classOf[Path],
+ classOf[String],
+ classOf[JMap[String, String]],
+ JBoolean.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE)
+ private lazy val loadTableMethod =
+ findMethod(
+ classOf[Hive],
+ "loadTable",
+ classOf[Path],
+ classOf[String],
+ JBoolean.TYPE,
+ JBoolean.TYPE)
+ private lazy val loadDynamicPartitionsMethod =
+ findMethod(
+ classOf[Hive],
+ "loadDynamicPartitions",
+ classOf[Path],
+ classOf[String],
+ classOf[JMap[String, String]],
+ JBoolean.TYPE,
+ JInteger.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE)
+
+ override def setCurrentSessionState(state: SessionState): Unit = startMethod.invoke(null, state)
+
+ override def getDataLocation(table: Table): Option[String] =
+ Option(getDataLocationMethod.invoke(table)).map(_.toString())
+
+ override def setDataLocation(table: Table, loc: String): Unit =
+ setDataLocationMethod.invoke(table, new URI(loc))
+
+ override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
+ getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
+
+ override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
+ getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]
+
+ override def getDriverResults(driver: Driver): Seq[String] = {
+ val res = new JArrayList[String]()
+ getDriverResultsMethod.invoke(driver, res)
+ res.toSeq
+ }
+
+ override def loadPartition(
+ hive: Hive,
+ loadPath: Path,
+ tableName: String,
+ partSpec: JMap[String, String],
+ replace: Boolean,
+ holdDDLTime: Boolean,
+ inheritTableSpecs: Boolean,
+ isSkewedStoreAsSubdir: Boolean): Unit = {
+ loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean,
+ holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean)
+ }
+
+ override def loadTable(
+ hive: Hive,
+ loadPath: Path,
+ tableName: String,
+ replace: Boolean,
+ holdDDLTime: Boolean): Unit = {
+ loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean)
+ }
+
+ override def loadDynamicPartitions(
+ hive: Hive,
+ loadPath: Path,
+ tableName: String,
+ partSpec: JMap[String, String],
+ replace: Boolean,
+ numDP: Int,
+ holdDDLTime: Boolean,
+ listBucketingEnabled: Boolean): Unit = {
+ loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean,
+ numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean)
+ }
+
+}
+
+private[client] class Shim_v0_13 extends Shim_v0_12 {
+
+ private lazy val setCurrentSessionStateMethod =
+ findStaticMethod(
+ classOf[SessionState],
+ "setCurrentSessionState",
+ classOf[SessionState])
+ private lazy val setDataLocationMethod =
+ findMethod(
+ classOf[Table],
+ "setDataLocation",
+ classOf[Path])
+ private lazy val getAllPartitionsMethod =
+ findMethod(
+ classOf[Hive],
+ "getAllPartitionsOf",
+ classOf[Table])
+ private lazy val getCommandProcessorMethod =
+ findStaticMethod(
+ classOf[CommandProcessorFactory],
+ "get",
+ classOf[Array[String]],
+ classOf[HiveConf])
+ private lazy val getDriverResultsMethod =
+ findMethod(
+ classOf[Driver],
+ "getResults",
+ classOf[JList[Object]])
+
+ override def setCurrentSessionState(state: SessionState): Unit =
+ setCurrentSessionStateMethod.invoke(null, state)
+
+ override def setDataLocation(table: Table, loc: String): Unit =
+ setDataLocationMethod.invoke(table, new Path(loc))
+
+ override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
+ getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
+
+ override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
+ getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]
+
+ override def getDriverResults(driver: Driver): Seq[String] = {
+ val res = new JArrayList[Object]()
+ getDriverResultsMethod.invoke(driver, res)
+ res.map { r =>
+ r match {
+ case s: String => s
+ case a: Array[Object] => a(0).asInstanceOf[String]
+ }
+ }
+ }
+
+}
+
+private[client] class Shim_v0_14 extends Shim_v0_13 {
+
+ private lazy val loadPartitionMethod =
+ findMethod(
+ classOf[Hive],
+ "loadPartition",
+ classOf[Path],
+ classOf[String],
+ classOf[JMap[String, String]],
+ JBoolean.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE)
+ private lazy val loadTableMethod =
+ findMethod(
+ classOf[Hive],
+ "loadTable",
+ classOf[Path],
+ classOf[String],
+ JBoolean.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE)
+ private lazy val loadDynamicPartitionsMethod =
+ findMethod(
+ classOf[Hive],
+ "loadDynamicPartitions",
+ classOf[Path],
+ classOf[String],
+ classOf[JMap[String, String]],
+ JBoolean.TYPE,
+ JInteger.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE,
+ JBoolean.TYPE)
+
+ override def loadPartition(
+ hive: Hive,
+ loadPath: Path,
+ tableName: String,
+ partSpec: JMap[String, String],
+ replace: Boolean,
+ holdDDLTime: Boolean,
+ inheritTableSpecs: Boolean,
+ isSkewedStoreAsSubdir: Boolean): Unit = {
+ loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean,
+ holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean,
+ JBoolean.TRUE, JBoolean.FALSE)
+ }
+
+ override def loadTable(
+ hive: Hive,
+ loadPath: Path,
+ tableName: String,
+ replace: Boolean,
+ holdDDLTime: Boolean): Unit = {
+ loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean,
+ JBoolean.TRUE, JBoolean.FALSE, JBoolean.FALSE)
+ }
+
+ override def loadDynamicPartitions(
+ hive: Hive,
+ loadPath: Path,
+ tableName: String,
+ partSpec: JMap[String, String],
+ replace: Boolean,
+ numDP: Int,
+ holdDDLTime: Boolean,
+ listBucketingEnabled: Boolean): Unit = {
+ loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean,
+ numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE)
+ }
+
+}
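
The shims above resolve each Hive method once in a lazy val and then call it reflectively, boxing primitive arguments (the `replace: JBoolean` style ascriptions). The sketch below shows the same lookup-once, invoke-with-boxed-arguments mechanics against a JDK method so it needs nothing from Hive; Integer.parseInt is only a stand-in for the Hive methods the shims target.

import java.lang.reflect.{Method, Modifier}

// Lookup-once, invoke-many reflection, mirroring findStaticMethod + lazy val.
object ReflectionSketch {
  private lazy val parseIntMethod: Method =
    classOf[java.lang.Integer].getMethod("parseInt", classOf[String])

  def main(args: Array[String]): Unit = {
    // Like findStaticMethod, verify the resolved method is actually static.
    require(Modifier.isStatic(parseIntMethod.getModifiers), "expected a static method")
    // Static methods take a null receiver; results come back boxed.
    val n = parseIntMethod.invoke(null, "42").asInstanceOf[java.lang.Integer]
    println(n.intValue())  // 42
  }
}
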
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
index 196a3d836cab..69cfc5c3c321 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.client
import java.io.File
+import java.lang.reflect.InvocationTargetException
import java.net.{URL, URLClassLoader}
import java.util
@@ -28,6 +29,7 @@ import org.apache.commons.io.{FileUtils, IOUtils}
import org.apache.spark.Logging
import org.apache.spark.deploy.SparkSubmitUtils
+import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.hive.HiveContext
@@ -48,29 +50,27 @@ private[hive] object IsolatedClientLoader {
def hiveVersion(version: String): HiveVersion = version match {
case "12" | "0.12" | "0.12.0" => hive.v12
case "13" | "0.13" | "0.13.0" | "0.13.1" => hive.v13
+ case "14" | "0.14" | "0.14.0" => hive.v14
}
private def downloadVersion(version: HiveVersion): Seq[URL] = {
- val hiveArtifacts =
- (Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") ++
- (if (version.hasBuiltinsJar) "hive-builtins" :: Nil else Nil))
- .map(a => s"org.apache.hive:$a:${version.fullVersion}") :+
- "com.google.guava:guava:14.0.1" :+
- "org.apache.hadoop:hadoop-client:2.4.0"
+ val hiveArtifacts = version.extraDeps ++
+ Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde")
+ .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++
+ Seq("com.google.guava:guava:14.0.1",
+ "org.apache.hadoop:hadoop-client:2.4.0")
val classpath = quietly {
SparkSubmitUtils.resolveMavenCoordinates(
hiveArtifacts.mkString(","),
Some("http://www.datanucleus.org/downloads/maven2"),
- None)
+ None,
+ exclusions = version.exclusions)
}
val allFiles = classpath.split(",").map(new File(_)).toSet
// TODO: Remove copy logic.
- val tempDir = File.createTempFile("hive", "v" + version.toString)
- tempDir.delete()
- tempDir.mkdir()
-
+ val tempDir = Utils.createTempDir(namePrefix = s"hive-${version}")
allFiles.foreach(f => FileUtils.copyFileToDirectory(f, tempDir))
tempDir.listFiles().map(_.toURL)
}
@@ -90,14 +90,14 @@ private[hive] object IsolatedClientLoader {
* `ClientInterface`, unless `isolationOn` is set to `false`.
*
* @param version The version of hive on the classpath. used to pick specific function signatures
- * that are not compatibile accross versions.
+ * that are not compatible across versions.
* @param execJars A collection of jar files that must include hive and hadoop.
* @param config A set of options that will be added to the HiveConf of the constructed client.
* @param isolationOn When true, custom versions of barrier classes will be constructed. Must be
* true unless loading the version of hive that is on Sparks classloader.
- * @param rootClassLoader The system root classloader. Must not know about hive classes.
- * @param baseClassLoader The spark classloader that is used to load shared classes.
- *
+ * @param rootClassLoader The system root classloader.
+ * @param baseClassLoader The spark classloader that is used to load shared classes. Must not know
+ * about Hive classes.
*/
private[hive] class IsolatedClientLoader(
val version: HiveVersion,
@@ -110,7 +110,7 @@ private[hive] class IsolatedClientLoader(
val barrierPrefixes: Seq[String] = Seq.empty)
extends Logging {
- // Check to make sure that the root classloader does not know about Hive.
+ // Check to make sure that the base classloader does not know about Hive.
assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure)
/** All jars used by the hive specific classloader. */
@@ -129,7 +129,7 @@ private[hive] class IsolatedClientLoader(
/** True if `name` refers to a spark class that must see specific version of Hive. */
protected def isBarrierClass(name: String): Boolean =
name.startsWith(classOf[ClientWrapper].getName) ||
- name.startsWith(classOf[ReflectionMagic].getName) ||
+ name.startsWith(classOf[Shim].getName) ||
barrierPrefixes.exists(name.startsWith)
protected def classToPath(name: String): String =
@@ -170,11 +170,16 @@ private[hive] class IsolatedClientLoader(
.newInstance(version, config)
.asInstanceOf[ClientInterface]
} catch {
- case ReflectionException(cnf: NoClassDefFoundError) =>
- throw new ClassNotFoundException(
- s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" +
- "Please make sure that jars for your version of hive and hadoop are included in the " +
- s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.")
+ case e: InvocationTargetException =>
+ if (e.getCause().isInstanceOf[NoClassDefFoundError]) {
+ val cnf = e.getCause().asInstanceOf[NoClassDefFoundError]
+ throw new ClassNotFoundException(
+ s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" +
+ "Please make sure that jars for your version of hive and hadoop are included in the " +
+ s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.")
+ } else {
+ throw e
+ }
} finally {
Thread.currentThread.setContextClassLoader(baseClassLoader)
}
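
The new catch block above relies on the fact that constructing the client reflectively surfaces constructor failures as InvocationTargetException, with the interesting error available from getCause. A minimal, self-contained illustration of that wrapping follows; FailingCtor and UnwrapSketch are made-up names used only to trigger the same path.

import java.lang.reflect.InvocationTargetException

// A constructor that fails the same way a missing Hive class would.
class FailingCtor { throw new NoClassDefFoundError("org/example/MissingClass") }

object UnwrapSketch {
  def main(args: Array[String]): Unit = {
    try {
      // Constructor.newInstance wraps anything thrown by the constructor.
      classOf[FailingCtor].getConstructor().newInstance()
    } catch {
      case e: InvocationTargetException if e.getCause.isInstanceOf[NoClassDefFoundError] =>
        println(s"root cause: ${e.getCause}")
      case e: InvocationTargetException =>
        throw e  // anything else is rethrown, as in the patch
    }
  }
}
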
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala
deleted file mode 100644
index c600b158c546..000000000000
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala
+++ /dev/null
@@ -1,208 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.client
-
-import scala.reflect._
-
-/** Unwraps reflection exceptions. */
-private[client] object ReflectionException {
- def unapply(a: Throwable): Option[Throwable] = a match {
- case ite: java.lang.reflect.InvocationTargetException => Option(ite.getCause)
- case _ => None
- }
-}
-
-/**
- * Provides implicit functions on any object for calling methods reflectively.
- */
-protected trait ReflectionMagic {
- /** code for InstanceMagic
- println(
- (1 to 22).map { n =>
- def repeat(str: String => String) = (1 to n).map(i => str(i.toString)).mkString(", ")
- val types = repeat(n => s"A$n <: AnyRef : ClassTag")
- val inArgs = repeat(n => s"a$n: A$n")
- val erasure = repeat(n => s"classTag[A$n].erasure")
- val outArgs = repeat(n => s"a$n")
- s"""|def call[$types, R](name: String, $inArgs): R = {
- | clazz.getMethod(name, $erasure).invoke(a, $outArgs).asInstanceOf[R]
- |}""".stripMargin
- }.mkString("\n")
- )
- */
-
- // scalastyle:off
- protected implicit class InstanceMagic(a: Any) {
- private val clazz = a.getClass
-
- def call[R](name: String): R = {
- clazz.getMethod(name).invoke(a).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, R](name: String, a1: A1): R = {
- clazz.getMethod(name, classTag[A1].erasure).invoke(a, a1).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure).invoke(a, a1, a2).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure).invoke(a, a1, a2, a3).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure).invoke(a, a1, a2, a3, a4).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure).invoke(a, a1, a2, a3, a4, a5).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure).invoke(a, a1, a2, a3, a4, a5, a6).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21).asInstanceOf[R]
- }
- def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, A22 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21, a22: A22): R = {
- clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure, classTag[A22].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22).asInstanceOf[R]
- }
- }
-
- /** code for StaticMagic
- println(
- (1 to 22).map { n =>
- def repeat(str: String => String) = (1 to n).map(i => str(i.toString)).mkString(", ")
- val types = repeat(n => s"A$n <: AnyRef : ClassTag")
- val inArgs = repeat(n => s"a$n: A$n")
- val erasure = repeat(n => s"classTag[A$n].erasure")
- val outArgs = repeat(n => s"a$n")
- s"""|def callStatic[$types, R](name: String, $inArgs): R = {
- | c.getDeclaredMethod(name, $erasure).invoke(c, $outArgs).asInstanceOf[R]
- |}""".stripMargin
- }.mkString("\n")
- )
- */
-
- protected implicit class StaticMagic(c: Class[_]) {
- def callStatic[A1 <: AnyRef : ClassTag, R](name: String, a1: A1): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure).invoke(c, a1).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure).invoke(c, a1, a2).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure).invoke(c, a1, a2, a3).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure).invoke(c, a1, a2, a3, a4).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure).invoke(c, a1, a2, a3, a4, a5).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure).invoke(c, a1, a2, a3, a4, a5, a6).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21).asInstanceOf[R]
- }
- def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, A22 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21, a22: A22): R = {
- c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure, classTag[A22].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22).asInstanceOf[R]
- }
- }
- // scalastyle:on
-}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
index 7db9200d4744..27a3d8f5896c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
@@ -19,15 +19,27 @@ package org.apache.spark.sql.hive
/** Support for interacting with different versions of the HiveMetastoreClient */
package object client {
- private[client] abstract class HiveVersion(val fullVersion: String, val hasBuiltinsJar: Boolean)
+ private[client] abstract class HiveVersion(
+ val fullVersion: String,
+ val extraDeps: Seq[String] = Nil,
+ val exclusions: Seq[String] = Nil)
// scalastyle:off
private[client] object hive {
- case object v10 extends HiveVersion("0.10.0", true)
- case object v11 extends HiveVersion("0.11.0", false)
- case object v12 extends HiveVersion("0.12.0", false)
- case object v13 extends HiveVersion("0.13.1", false)
+ case object v12 extends HiveVersion("0.12.0")
+ case object v13 extends HiveVersion("0.13.1")
+
+ // Hive 0.14 depends on calcite 0.9.2-incubating-SNAPSHOT, which no longer exists in
+ // Maven Central, so override those dependencies with a version that does.
+ //
+ // org.pentaho:pentaho-aggdesigner-algorithm is also nowhere to be found, so exclude
+ // it explicitly. If the metastore client needs it, users will have to obtain it
+ // themselves and use configuration to point Spark at the correct jars.
+ case object v14 extends HiveVersion("0.14.0",
+ Seq("org.apache.calcite:calcite-core:1.3.0-incubating",
+ "org.apache.calcite:calcite-avatica:1.3.0-incubating"),
+ Seq("org.pentaho:pentaho-aggdesigner-algorithm"))
}
// scalastyle:on
-
-}
\ No newline at end of file
+
+}
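
The new `extraDeps` and `exclusions` fields let a version descriptor carry the extra Maven coordinates to fetch and the artifacts to drop when the metastore client jars are resolved. A rough sketch of how such a descriptor might be consumed — `downloadArtifact` is a hypothetical resolver callback, not part of this patch:

    // Hypothetical helper: collect the jars for one Hive version, honoring its
    // extraDeps and exclusions. `downloadArtifact` stands in for an Ivy/Maven resolver.
    def clientJars(
        fullVersion: String,
        extraDeps: Seq[String],
        exclusions: Seq[String],
        downloadArtifact: String => Seq[java.io.File]): Seq[java.io.File] = {
      val coordinates =
        Seq(s"org.apache.hive:hive-metastore:$fullVersion",
            s"org.apache.hive:hive-exec:$fullVersion") ++ extraDeps
      coordinates
        .flatMap(downloadArtifact)
        // Drop anything matching an excluded artifact, e.g. the missing pentaho jar.
        .filterNot(jar => exclusions.exists(excl => jar.getName.contains(excl.split(':').last)))
    }

In the real patch the resolution is presumably handled by the isolated client loader; the point here is only that v14's Calcite override and Pentaho exclusion travel with the version constant.
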
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
index 7d3ec12c4eb0..0e4a2427a9c1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.{AnalysisException, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.expressions.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.client.{HiveTable, HiveColumn}
@@ -42,7 +42,7 @@ case class CreateTableAsSelect(
def database: String = tableDesc.database
def tableName: String = tableDesc.name
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
lazy val metastoreRelation: MetastoreRelation = {
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
@@ -50,17 +50,25 @@ case class CreateTableAsSelect(
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.TextInputFormat
- val withSchema =
+ val withFormat =
tableDesc.copy(
- schema =
- query.output.map(c =>
- HiveColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType), null)),
inputFormat =
tableDesc.inputFormat.orElse(Some(classOf[TextInputFormat].getName)),
outputFormat =
tableDesc.outputFormat
.orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)),
serde = tableDesc.serde.orElse(Some(classOf[LazySimpleSerDe].getName())))
+
+ val withSchema = if (withFormat.schema.isEmpty) {
+ // Hive doesn't support specifying a column list for the target table in CTAS,
+ // but we don't think Spark SQL should follow that restriction.
+ tableDesc.copy(schema =
+ query.output.map(c =>
+ HiveColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType), null)))
+ } else {
+ withFormat
+ }
+
hiveContext.catalog.client.createTable(withSchema)
// Get the Metastore Relation
@@ -81,7 +89,7 @@ case class CreateTableAsSelect(
hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd
}
- Seq.empty[Row]
+ Seq.empty[InternalRow]
}
override def argString: String = {
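
After this change a plain CTAS still derives its columns from the SELECT, while a table description that already carries a schema keeps it. A hedged usage sketch of the common path (table names are illustrative and `hiveContext` is assumed to be an initialized HiveContext):

    // No explicit schema on the table description, so the columns (key, value)
    // come from the query's output attributes.
    hiveContext.sql("CREATE TABLE ctas_example STORED AS TEXTFILE AS SELECT key, value FROM src")
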
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
index 6fce69b58b85..a89381000ad5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
@@ -21,12 +21,10 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.hive.metastore.api.FieldSchema
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Row}
-import org.apache.spark.sql.execution.{SparkPlan, RunnableCommand}
-import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation}
-import org.apache.spark.sql.hive.HiveShim
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow}
+import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.sql.hive.MetastoreRelation
/**
* Implementation for "describe [extended] table".
@@ -37,7 +35,7 @@ case class DescribeHiveTableCommand(
override val output: Seq[Attribute],
isExtended: Boolean) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
// Trying to mimic the format of Hive's output. But not exactly the same.
var results: Seq[(String, String, String)] = Nil
@@ -59,7 +57,7 @@ case class DescribeHiveTableCommand(
}
results.map { case (name, dataType, comment) =>
- Row(name, dataType, comment)
+ InternalRow(name, dataType, comment)
}
}
}
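
DescribeHiveTableCommand is one of many commands in this patch whose `run` now returns `Seq[InternalRow]`: results are produced in the engine's internal representation, and conversion to the public `Row` happens at the collection boundary. A minimal sketch of the pattern for a hypothetical command (`StaticResultCommand` does not exist in this patch):

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InternalRow}
    import org.apache.spark.sql.execution.RunnableCommand
    import org.apache.spark.sql.types.StringType

    case class StaticResultCommand(value: String) extends RunnableCommand {
      override val output: Seq[Attribute] =
        Seq(AttributeReference("result", StringType, nullable = false)())

      // The command computes internal rows; the engine converts them to external
      // Rows when the result is handed back to the user.
      override def run(sqlContext: SQLContext): Seq[InternalRow] =
        Seq(InternalRow(value))
    }
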
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala
index 60a9bb630d0d..87f8e3f7fcfc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala
@@ -1,34 +1,34 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.execution
-
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row}
-import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.types.StringType
-
-private[hive]
-case class HiveNativeCommand(sql: String) extends RunnableCommand {
-
- override def output: Seq[AttributeReference] =
- Seq(AttributeReference("result", StringType, nullable = false)())
-
- override def run(sqlContext: SQLContext): Seq[Row] =
- sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_))
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, InternalRow}
+import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.types.StringType
+
+private[hive]
+case class HiveNativeCommand(sql: String) extends RunnableCommand {
+
+ override def output: Seq[AttributeReference] =
+ Seq(AttributeReference("result", StringType, nullable = false)())
+
+ override def run(sqlContext: SQLContext): Seq[InternalRow] =
+ sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(InternalRow(_))
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index 62dc4167b78d..1f5e4af2e474 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -63,7 +63,7 @@ case class HiveTableScan(
BindReferences.bindReference(pred, relation.partitionKeys)
}
- // Create a local copy of hiveconf,so that scan specific modifications should not impact
+ // Create a local copy of hiveconf, so that scan-specific modifications do not impact
// other queries
@transient
private[this] val hiveExtraConf = new HiveConf(context.hiveconf)
@@ -72,7 +72,7 @@ case class HiveTableScan(
addColumnMetadataToConf(hiveExtraConf)
@transient
- private[this] val hadoopReader =
+ private[this] val hadoopReader =
new HadoopTableReader(attributes, relation, context, hiveExtraConf)
private[this] def castFromString(value: String, dataType: DataType) = {
@@ -129,7 +129,7 @@ case class HiveTableScan(
}
}
- protected override def doExecute(): RDD[Row] = if (!relation.hiveQlTable.isPartitioned) {
+ protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 8613332186f2..1d306c5d10af 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -19,27 +19,26 @@ package org.apache.spark.sql.hive.execution
import java.util
-import scala.collection.JavaConversions._
-
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.metastore.MetaStoreUtils
-import org.apache.hadoop.hive.ql.metadata.Hive
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.ql.{Context, ErrorMsg}
import org.apache.hadoop.hive.serde2.Serializer
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
import org.apache.hadoop.hive.serde2.objectinspector._
-import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf}
+import org.apache.hadoop.mapred.{FileOutputFormat, JobConf}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Row}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow}
import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
+import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
import org.apache.spark.sql.hive._
-import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc}
-import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.{SerializableWritable, SparkException, TaskContext}
+import scala.collection.JavaConversions._
+
private[hive]
case class InsertIntoHiveTable(
table: MetastoreRelation,
@@ -62,7 +61,7 @@ case class InsertIntoHiveTable(
def output: Seq[Attribute] = child.output
def saveAsHiveFile(
- rdd: RDD[Row],
+ rdd: RDD[InternalRow],
valueClass: Class[_],
fileSinkConf: FileSinkDesc,
conf: SerializableWritable[JobConf],
@@ -84,7 +83,7 @@ case class InsertIntoHiveTable(
writerContainer.commitJob()
// Note that this function is executed on executor side
- def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = {
+ def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
val serializer = newSerializer(fileSinkConf.getTableInfo)
val standardOI = ObjectInspectorUtils
.getStandardObjectInspector(
@@ -121,12 +120,12 @@ case class InsertIntoHiveTable(
*
* Note: this is run once and then kept to avoid double insertions.
*/
- protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
// Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer
// instances within the closure, since Serializer is not serializable while TableDesc is.
val tableDesc = table.tableDesc
val tableLocation = table.hiveQlTable.getDataLocation
- val tmpLocation = HiveShim.getExternalTmpPath(hiveContext, tableLocation)
+ val tmpLocation = hiveContext.getExternalTmpPath(tableLocation.toUri)
val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false)
val isCompressed = sc.hiveconf.getBoolean(
ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal)
@@ -252,12 +251,13 @@ case class InsertIntoHiveTable(
// however for now we return an empty list to simplify compatibility checks with hive, which
// does not return anything for insert operations.
// TODO: implement hive compatibility as rules.
- Seq.empty[Row]
+ Seq.empty[InternalRow]
}
- override def executeCollect(): Array[Row] = sideEffectResult.toArray
+ override def executeCollect(): Array[Row] =
+ sideEffectResult.toArray
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}
}
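
Worth noting in the hunk above: `doExecute` stays in the internal world (`RDD[InternalRow]`), while `executeCollect` remains the user-facing boundary and still returns `Array[Row]`. Seen from a caller, as a sketch with illustrative table names:

    // External side: collect() hands back org.apache.spark.sql.Row values.
    val rows: Array[org.apache.spark.sql.Row] =
      hiveContext.sql("INSERT INTO TABLE target SELECT * FROM src").collect()

    // Internal side: the physical plan itself moves InternalRow between operators.
    val internalRdd = hiveContext.sql("SELECT * FROM src").queryExecution.toRdd
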
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 6f27a8626fc1..9d8872aa47d1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.execution
import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader}
+import java.lang.ProcessBuilder.Redirect
import java.util.Properties
import scala.collection.JavaConversions._
@@ -54,19 +55,25 @@ case class ScriptTransformation(
override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
- protected override def doExecute(): RDD[Row] = {
+ protected override def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitions { iter =>
val cmd = List("/bin/bash", "-c", script)
val builder = new ProcessBuilder(cmd)
+ // redirectError(Redirect.INHERIT) consumes the error output from the buffer and
+ // prints it to stderr (inheriting the target from the current Scala process).
+ // Without this there would be two issues:
+ // 1) The error messages generated by the script process would be hidden.
+ // 2) If the error output is large enough to fill up the buffer, the input logic would hang.
+ builder.redirectError(Redirect.INHERIT)
val proc = builder.start()
val inputStream = proc.getInputStream
val outputStream = proc.getOutputStream
val reader = new BufferedReader(new InputStreamReader(inputStream))
-
+
val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output)
- val iterator: Iterator[Row] = new Iterator[Row] with HiveInspectors {
- var cacheRow: Row = null
+ val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
+ var cacheRow: InternalRow = null
var curLine: String = null
var eof: Boolean = false
@@ -83,7 +90,7 @@ case class ScriptTransformation(
}
}
- def deserialize(): Row = {
+ def deserialize(): InternalRow = {
if (cacheRow != null) return cacheRow
val mutableRow = new SpecificMutableRow(output.map(_.dataType))
@@ -95,7 +102,7 @@ case class ScriptTransformation(
val raw = outputSerde.deserialize(writable)
val dataList = outputSoi.getStructFieldsDataAsList(raw)
val fieldList = outputSoi.getAllStructFieldRefs()
-
+
var i = 0
dataList.foreach( element => {
if (element == null) {
@@ -113,11 +120,11 @@ case class ScriptTransformation(
}
}
- override def next(): Row = {
+ override def next(): InternalRow = {
if (!hasNext) {
throw new NoSuchElementException
}
-
+
if (outputSerde == null) {
val prevLine = curLine
curLine = reader.readLine()
@@ -192,7 +199,7 @@ case class HiveScriptIOSchema (
val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
-
+
def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = {
val (columns, columnTypes) = parseAttrs(input)
val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps)
@@ -206,13 +213,13 @@ case class HiveScriptIOSchema (
}
def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
-
+
val columns = attrs.map {
case aref: AttributeReference => aref.name
case e: NamedExpression => e.name
case _ => null
}
-
+
val columnTypes = attrs.map {
case aref: AttributeReference => aref.dataType
case e: NamedExpression => e.dataType
@@ -221,7 +228,7 @@ case class HiveScriptIOSchema (
(columns, columnTypes)
}
-
+
def initSerDe(serdeClassName: String, columns: Seq[String],
columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = {
@@ -240,7 +247,7 @@ case class HiveScriptIOSchema (
(kv._1.split("'")(1), kv._2.split("'")(1))
}).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
-
+
val properties = new Properties()
properties.putAll(propsMap)
serde.initialize(null, properties)
@@ -261,7 +268,7 @@ case class HiveScriptIOSchema (
null
}
}
-
+
def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = {
if (outputSerde != null) {
outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector]
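
The Redirect.INHERIT change above is plain `java.lang.ProcessBuilder` behaviour rather than anything Hive-specific: if the child's stderr is left on a pipe that nobody reads, its messages are invisible and the child can block once the pipe buffer fills. A standalone sketch (the script path is made up):

    import java.lang.ProcessBuilder.Redirect
    import scala.io.Source

    val builder = new ProcessBuilder("/bin/bash", "-c", "./my_transform.sh")
    // Send the child's stderr straight to this JVM's stderr instead of buffering it.
    builder.redirectError(Redirect.INHERIT)
    val proc = builder.start()
    // Only stdout still has to be drained by the caller.
    val stdout = Source.fromInputStream(proc.getInputStream).mkString
    proc.waitFor()
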
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 0ba94d7b7c64..aad58bfa2e6e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Row}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.HiveContext
@@ -39,9 +39,9 @@ import org.apache.spark.util.Utils
private[hive]
case class AnalyzeTable(tableName: String) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
sqlContext.asInstanceOf[HiveContext].analyze(tableName)
- Seq.empty[Row]
+ Seq.empty[InternalRow]
}
}
@@ -53,7 +53,7 @@ case class DropTable(
tableName: String,
ifExists: Boolean) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
val ifExistsClause = if (ifExists) "IF EXISTS " else ""
try {
@@ -70,7 +70,7 @@ case class DropTable(
hiveContext.invalidateTable(tableName)
hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName")
hiveContext.catalog.unregisterTable(Seq(tableName))
- Seq.empty[Row]
+ Seq.empty[InternalRow]
}
}
@@ -83,7 +83,7 @@ case class AddJar(path: String) extends RunnableCommand {
schema.toAttributes
}
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
val currentClassLoader = Utils.getContextOrSparkClassLoader
@@ -91,26 +91,32 @@ case class AddJar(path: String) extends RunnableCommand {
val jarURL = new java.io.File(path).toURL
val newClassLoader = new java.net.URLClassLoader(Array(jarURL), currentClassLoader)
Thread.currentThread.setContextClassLoader(newClassLoader)
- org.apache.hadoop.hive.ql.metadata.Hive.get().getConf().setClassLoader(newClassLoader)
-
- // Add jar to isolated hive classloader
+ // We need to explicitly set the class loader associated with the conf in executionHive's
+ // state because this class loader will be used as the context class loader of the current
+ // thread to execute any Hive command.
+ // We cannot use `org.apache.hadoop.hive.ql.metadata.Hive.get().getConf()` because Hive.get()
+ // returns the value of a thread local variable and its HiveConf may not be the HiveConf
+ // associated with `executionHive.state` (for example, HiveContext is created in one thread
+ // and then add jar is called from another thread).
+ hiveContext.executionHive.state.getConf.setClassLoader(newClassLoader)
+ // Add jar to isolated hive (metadataHive) class loader.
hiveContext.runSqlHive(s"ADD JAR $path")
// Add jar to executors
hiveContext.sparkContext.addJar(path)
- Seq(Row(0))
+ Seq(InternalRow(0))
}
}
private[hive]
case class AddFile(path: String) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
hiveContext.runSqlHive(s"ADD FILE $path")
hiveContext.sparkContext.addFile(path)
- Seq.empty[Row]
+ Seq.empty[InternalRow]
}
}
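
The class-loader comment in AddJar boils down to ordinary `URLClassLoader` plumbing: build a child loader over the current one, make it the thread's context loader, and hand the same loader to whatever caches a HiveConf. A minimal sketch in isolation (the `conf` parameter stands in for `executionHive.state.getConf`):

    import java.net.URLClassLoader
    import org.apache.hadoop.hive.conf.HiveConf

    def installJar(path: String, conf: HiveConf): Unit = {
      val current = Thread.currentThread.getContextClassLoader
      val withJar = new URLClassLoader(Array(new java.io.File(path).toURI.toURL), current)
      // The thread and the cached HiveConf must agree on the loader, otherwise
      // classes from the newly added jar stay invisible to Hive commands.
      Thread.currentThread.setContextClassLoader(withJar)
      conf.setClassLoader(withJar)
    }
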
@@ -123,12 +129,12 @@ case class CreateMetastoreDataSource(
allowExisting: Boolean,
managedIfNoPath: Boolean) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
if (hiveContext.catalog.tableExists(tableName :: Nil)) {
if (allowExisting) {
- return Seq.empty[Row]
+ return Seq.empty[InternalRow]
} else {
throw new AnalysisException(s"Table $tableName already exists.")
}
@@ -151,7 +157,7 @@ case class CreateMetastoreDataSource(
optionsWithPath,
isExternal)
- Seq.empty[Row]
+ Seq.empty[InternalRow]
}
}
@@ -164,7 +170,7 @@ case class CreateMetastoreDataSourceAsSelect(
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sqlContext: SQLContext): Seq[InternalRow] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
var createMetastoreTable = false
var isExternal = true
@@ -188,7 +194,7 @@ case class CreateMetastoreDataSourceAsSelect(
s"Or, if you are using SQL CREATE TABLE, you need to drop $tableName first.")
case SaveMode.Ignore =>
// Since the table already exists and the save mode is Ignore, we will just return.
- return Seq.empty[Row]
+ return Seq.empty[InternalRow]
case SaveMode.Append =>
// Check if the specified data source match the data source of the existing table.
val resolved = ResolvedDataSource(
@@ -230,7 +236,7 @@ case class CreateMetastoreDataSourceAsSelect(
val data = DataFrame(hiveContext, query)
val df = existingSchema match {
// If we are inserting into an existing table, just use the existing schema.
- case Some(schema) => sqlContext.createDataFrame(data.queryExecution.toRdd, schema)
+ case Some(schema) => sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, schema)
case None => data
}
@@ -253,6 +259,6 @@ case class CreateMetastoreDataSourceAsSelect(
// Refresh the cache of the table in the catalog.
hiveContext.refreshTable(tableName)
- Seq.empty[Row]
+ Seq.empty[InternalRow]
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index bb116e3ab7de..4986b1ea9d90 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -17,11 +17,9 @@
package org.apache.spark.sql.hive
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
-import org.apache.spark.sql.AnalysisException
-
import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConversions._
+import scala.util.Try
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
@@ -30,47 +28,55 @@ import org.apache.hadoop.hive.ql.exec._
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
import org.apache.spark.Logging
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.types._
-/* Implicit conversions */
-import scala.collection.JavaConversions._
-private[hive] abstract class HiveFunctionRegistry
+private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
extends analysis.FunctionRegistry with HiveInspectors {
def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
- def lookupFunction(name: String, children: Seq[Expression]): Expression = {
- // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
- // not always serializable.
- val functionInfo: FunctionInfo =
- Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
- sys.error(s"Couldn't find function $name"))
-
- val functionClassName = functionInfo.getFunctionClass.getName
-
- if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
- } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
- } else if (
- classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
- } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
- } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
- } else {
- sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
+ override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ Try(underlying.lookupFunction(name, children)).getOrElse {
+ // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
+ // not always serializable.
+ val functionInfo: FunctionInfo =
+ Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
+ throw new AnalysisException(s"undefined function $name"))
+
+ val functionClassName = functionInfo.getFunctionClass.getName
+
+ if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
+ } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
+ } else if (
+ classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
+ } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
+ } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
+ } else {
+ sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
+ }
}
}
+
+ override def registerFunction(name: String, builder: FunctionBuilder): Unit =
+ throw new UnsupportedOperationException
}
private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
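
The reworked registry wraps the native Spark FunctionRegistry and only consults Hive's FunctionRegistry when the native lookup throws, so built-in functions shadow Hive UDFs of the same name. The fallback shape in isolation, as a generic sketch rather than the exact Spark types:

    import scala.util.Try

    trait Lookup[K, V] { def lookup(key: K): V }  // expected to throw for unknown keys

    // Try the primary source first; only on failure fall back to the secondary one.
    class ChainedLookup[K, V](primary: Lookup[K, V], secondary: Lookup[K, V])
      extends Lookup[K, V] {
      override def lookup(key: K): V =
        Try(primary.lookup(key)).getOrElse(secondary.lookup(key))
    }
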
@@ -78,6 +84,8 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre
type UDFType = UDF
+ override def deterministic: Boolean = isUDFDeterministic
+
override def nullable: Boolean = true
@transient
@@ -112,8 +120,10 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre
@transient
protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
+ override def isThreadSafe: Boolean = false
+
// TODO: Finish input output types.
- override def eval(input: Row): Any = {
+ override def eval(input: InternalRow): Any = {
unwrap(
FunctionRegistry.invoke(method, function, conversionHelper
.convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
@@ -140,6 +150,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
extends Expression with HiveInspectors with Logging {
type UDFType = GenericUDF
+ override def deterministic: Boolean = isUDFDeterministic
+
override def nullable: Boolean = true
@transient
@@ -168,7 +180,9 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
lazy val dataType: DataType = inspectorToDataType(returnInspector)
- override def eval(input: Row): Any = {
+ override def isThreadSafe: Boolean = false
+
+ override def eval(input: InternalRow): Any = {
returnInspector // Make sure initialized.
var i = 0
@@ -335,7 +349,7 @@ private[hive] case class HiveWindowFunction(
def nullable: Boolean = true
- override def eval(input: Row): Any =
+ override def eval(input: InternalRow): Any =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
@transient
@@ -359,7 +373,7 @@ private[hive] case class HiveWindowFunction(
evaluator.reset(hiveEvaluatorBuffer)
}
- override def prepareInputParameters(input: Row): AnyRef = {
+ override def prepareInputParameters(input: InternalRow): AnyRef = {
wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length))
}
// Add input parameters for a single row.
@@ -502,7 +516,7 @@ private[hive] case class HiveGenericUdtf(
field => (inspectorToDataType(field.getFieldObjectInspector), true)
}
- override def eval(input: Row): TraversableOnce[Row] = {
+ override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
outputInspector // Make sure initialized.
val inputProjection = new InterpretedProjection(children)
@@ -512,23 +526,23 @@ private[hive] case class HiveGenericUdtf(
}
protected class UDTFCollector extends Collector {
- var collected = new ArrayBuffer[Row]
+ var collected = new ArrayBuffer[InternalRow]
override def collect(input: java.lang.Object) {
// We need to clone the input here because implementations of
// GenericUDTF reuse the same object. Luckily they are always an array, so
// it is easy to clone.
- collected += unwrap(input, outputInspector).asInstanceOf[Row]
+ collected += unwrap(input, outputInspector).asInstanceOf[InternalRow]
}
- def collectRows(): Seq[Row] = {
+ def collectRows(): Seq[InternalRow] = {
val toCollect = collected
- collected = new ArrayBuffer[Row]
+ collected = new ArrayBuffer[InternalRow]
toCollect
}
}
- override def terminate(): TraversableOnce[Row] = {
+ override def terminate(): TraversableOnce[InternalRow] = {
outputInspector // Make sure initialized.
function.close()
collector.collectRows()
@@ -555,12 +569,12 @@ private[hive] case class HiveUdafFunction(
} else {
funcWrapper.createFunction[AbstractGenericUDAFResolver]()
}
-
+
private val inspectors = exprs.map(toInspector).toArray
-
- private val function = {
+
+ private val function = {
val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
- resolver.getEvaluator(parameterInfo)
+ resolver.getEvaluator(parameterInfo)
}
private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
@@ -568,15 +582,15 @@ private[hive] case class HiveUdafFunction(
private val buffer =
function.getNewAggregationBuffer
- override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector)
+ override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector)
@transient
val inputProjection = new InterpretedProjection(exprs)
@transient
protected lazy val cached = new Array[AnyRef](exprs.length)
-
- def update(input: Row): Unit = {
+
+ def update(input: InternalRow): Unit = {
val inputs = inputProjection(input)
function.iterate(buffer, wrap(inputs, inspectors, cached))
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index 2bb526b14be3..ee440e304ec1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -35,8 +35,7 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql.Row
import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter}
import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc}
-import org.apache.spark.sql.hive.HiveShim._
+import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
import org.apache.spark.sql.types._
/**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 7c7afc824d7a..92155096202b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -49,7 +49,7 @@ import scala.collection.JavaConversions._
object TestHive
extends TestHiveContext(
new SparkContext(
- "local[2]",
+ System.getProperty("spark.sql.test.master", "local[2]"),
"TestSQLContext",
new SparkConf()
.set("spark.sql.test", "")
diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247
index 27de46fdf22a..84a31a5a6970 100644
--- a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247
+++ b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247
@@ -1 +1 @@
--0.0010000000000000009
+-0.001
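
The updated golden value drops trailing floating-point noise: the old string is the kind of output plain `Double` arithmetic produces, whereas exact decimal arithmetic yields `-0.001`. A generic illustration — the operands below are made up and are not what the test actually computes:

    // Binary rounding error leaks into the printed Double result.
    println(0.2 - 0.201)                              // e.g. -0.0010000000000000009
    // Exact decimal arithmetic keeps the value a user expects.
    println(BigDecimal("0.2") - BigDecimal("0.201"))  // -0.001
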
diff --git a/sql/hive/src/test/resources/hive-contrib-0.13.1.jar b/sql/hive/src/test/resources/hive-contrib-0.13.1.jar
new file mode 100644
index 000000000000..ce0740d9245a
Binary files /dev/null and b/sql/hive/src/test/resources/hive-contrib-0.13.1.jar differ
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 945596db8032..39d315aaeab5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -57,7 +57,7 @@ class CachedTableSuite extends QueryTest {
checkAnswer(
sql("SELECT * FROM src s"),
preCacheResults)
-
+
uncacheTable("src")
assertCached(sql("SELECT * FROM src"), 0)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 80c2d32bf70d..aff0456b37ed 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -26,12 +26,13 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.io.LongWritable
-import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.expressions.{Literal, Row}
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.{Literal, InternalRow}
import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
-class HiveInspectorSuite extends FunSuite with HiveInspectors {
+class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
test("Test wrap SettableStructObjectInspector") {
val udaf = new UDAFPercentile.PercentileLongEvaluator()
udaf.init()
@@ -45,7 +46,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors {
classOf[UDAFPercentile.State],
ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector]
- val a = unwrap(state, soi).asInstanceOf[Row]
+ val a = unwrap(state, soi).asInstanceOf[InternalRow]
val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State]
val sfCounts = soi.getStructFieldRef("counts")
@@ -127,7 +128,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors {
}
}
- def checkValues(row1: Seq[Any], row2: Row): Unit = {
+ def checkValues(row1: Seq[Any], row2: InternalRow): Unit = {
row1.zip(row2.toSeq).foreach { case (r1, r2) =>
checkValue(r1, r2)
}
@@ -203,7 +204,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors {
})
checkValues(row,
- unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row])
+ unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[InternalRow])
checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index fa8e11ffec2b..e9bb32667936 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -17,13 +17,13 @@
package org.apache.spark.sql.hive
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.hive.test.TestHive
-import org.scalatest.FunSuite
import org.apache.spark.sql.test.ExamplePointUDT
import org.apache.spark.sql.types.StructType
-class HiveMetastoreCatalogSuite extends FunSuite {
+class HiveMetastoreCatalogSuite extends SparkFunSuite {
test("struct field should accept underscore in sub-column name") {
val metastr = "struct"
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
index 5a5ea10e3c82..a0d80dc39c10 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
@@ -17,10 +17,9 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.parquet.ParquetTest
-import org.apache.spark.sql.{QueryTest, SQLConf}
+import org.apache.spark.sql.{QueryTest, Row, SQLConf}
case class Cases(lower: String, UPPER: String)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
index 941a2941649b..f765395e148a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.hive
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable}
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-class HiveQlSuite extends FunSuite with BeforeAndAfterAll {
+class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
override def beforeAll() {
if (SessionState.get() == null) {
SessionState.start(new HiveConf())
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 9cc4685499f1..aa5dbe2db690 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -240,7 +240,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
checkAnswer(sql("select key,value from table_with_partition where ds='1' "),
testData.collect().toSeq
)
-
+
// test difference type of field
sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT")
checkAnswer(sql("select key,value from table_with_partition where ds='1' "),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 58e2d1fbfa73..79a85b24d2f6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -561,30 +561,28 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA
}
}
- if (HiveShim.version == "0.13.1") {
- test("scan a parquet table created through a CTAS statement") {
- withSQLConf(
- "spark.sql.hive.convertMetastoreParquet" -> "true",
- SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") {
-
- withTempTable("jt") {
- (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt")
-
- withTable("test_parquet_ctas") {
- sql(
- """CREATE TABLE test_parquet_ctas STORED AS PARQUET
- |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5
- """.stripMargin)
-
- checkAnswer(
- sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "),
- Row(3) :: Row(4) :: Nil)
-
- table("test_parquet_ctas").queryExecution.optimizedPlan match {
- case LogicalRelation(p: ParquetRelation2) => // OK
- case _ =>
- fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}")
- }
+ test("scan a parquet table created through a CTAS statement") {
+ withSQLConf(
+ "spark.sql.hive.convertMetastoreParquet" -> "true",
+ SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") {
+
+ withTempTable("jt") {
+ (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt")
+
+ withTable("test_parquet_ctas") {
+ sql(
+ """CREATE TABLE test_parquet_ctas STORED AS PARQUET
+ |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5
+ """.stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "),
+ Row(3) :: Row(4) :: Nil)
+
+ table("test_parquet_ctas").queryExecution.optimizedPlan match {
+ case LogicalRelation(p: ParquetRelation2) => // OK
+ case _ =>
+              fail(s"test_parquet_ctas should have been converted to ${classOf[ParquetRelation2]}")
}
}
}
@@ -835,4 +833,21 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA
(70 to 79).map(i => Row(i, s"str$i")))
}
}
+
+  test("SPARK-8156: create table in a specific database via 'use dbname'") {
+
+ val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
+ sqlContext.sql("""create database if not exists testdb8156""")
+ sqlContext.sql("""use testdb8156""")
+ df.write
+ .format("parquet")
+ .mode(SaveMode.Overwrite)
+ .saveAsTable("ttt3")
+
+ checkAnswer(
+ sqlContext.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"),
+ Row("ttt3", false))
+ sqlContext.sql("""use default""")
+ sqlContext.sql("""drop database if exists testdb8156 CASCADE""")
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
index 4990092df6a9..017bc2adc103 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
@@ -20,16 +20,17 @@ package org.apache.spark.sql.hive
import com.google.common.io.Files
import org.apache.spark.sql.{QueryTest, _}
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.util.Utils
class QueryPartitionSuite extends QueryTest {
- import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+ private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
+ import ctx.implicits._
+ import ctx.sql
test("SPARK-5068: query data when path doesn't exist"){
- val testData = TestHive.sparkContext.parallelize(
+ val testData = ctx.sparkContext.parallelize(
(1 to 10).map(i => TestData(i, i.toString))).toDF()
testData.registerTempTable("testData")
@@ -48,8 +49,8 @@ class QueryPartitionSuite extends QueryTest {
// test for the exist path
checkAnswer(sql("select key,value from table_with_partition"),
- testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect
- ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect)
+ testData.toDF.collect ++ testData.toDF.collect
+ ++ testData.toDF.collect ++ testData.toDF.collect)
// delete the path of one partition
tmpDir.listFiles
@@ -58,8 +59,7 @@ class QueryPartitionSuite extends QueryTest {
// test for after delete the path
checkAnswer(sql("select key,value from table_with_partition"),
- testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect
- ++ testData.toSchemaRDD.collect)
+ testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect)
sql("DROP TABLE table_with_partition")
sql("DROP TABLE createAndInsertTest")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
index 8afe5459d4f1..93dcb10f7a29 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
@@ -17,16 +17,13 @@
package org.apache.spark.sql.hive
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.sql.hive.test.TestHive
-class SerializationSuite extends FunSuite {
+class SerializationSuite extends SparkFunSuite {
test("[SPARK-5840] HiveContext should be serializable") {
- val hiveContext = TestHive
+ val hiveContext = org.apache.spark.sql.hive.test.TestHive
hiveContext.hiveconf
val serializer = new JavaSerializer(new SparkConf()).newInstance()
val bytes = serializer.serialize(hiveContext)
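The serialization check above relies on Spark's `JavaSerializer`. As a standalone illustration, the sketch below runs the same serialize/deserialize cycle on a plain value; `Payload` and `JavaSerializerRoundTrip` are made-up names for this example only.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

// Illustrative only: the same round trip the test applies to HiveContext,
// exercised here on a simple case class (case classes are Serializable).
case class Payload(id: Int, name: String)

object JavaSerializerRoundTrip {
  def main(args: Array[String]): Unit = {
    val serializer = new JavaSerializer(new SparkConf()).newInstance()
    val bytes = serializer.serialize(Payload(1, "hive"))      // java.nio.ByteBuffer
    val restored = serializer.deserialize[Payload](bytes)
    assert(restored == Payload(1, "hive"))
  }
}
```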
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 00a69de9e426..78c94e6490e3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -23,13 +23,18 @@ import scala.reflect.ClassTag
import org.apache.spark.sql.{Row, SQLConf, QueryTest}
import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.execution._
class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
- TestHive.reset()
- TestHive.cacheTables = false
+
+ private lazy val ctx: HiveContext = {
+ val ctx = org.apache.spark.sql.hive.test.TestHive
+ ctx.reset()
+ ctx.cacheTables = false
+ ctx
+ }
+
+ import ctx.sql
test("parse analyze commands") {
def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
@@ -72,17 +77,13 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
test("analyze MetastoreRelations") {
def queryTotalSize(tableName: String): BigInt =
- catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes
+ ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes
// Non-partitioned table
sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect()
sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect()
sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect()
- // TODO: How does it works? needs to add it back for other hive version.
- if (HiveShim.version =="0.12.0") {
- assert(queryTotalSize("analyzeTable") === conf.defaultSizeInBytes)
- }
sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan")
assert(queryTotalSize("analyzeTable") === BigInt(11624))
@@ -110,7 +111,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
|SELECT * FROM src
""".stripMargin).collect()
- assert(queryTotalSize("analyzeTable_part") === conf.defaultSizeInBytes)
+ assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes)
sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan")
@@ -121,9 +122,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
// Try to analyze a temp table
sql("""SELECT * FROM src""").registerTempTable("tempTable")
intercept[UnsupportedOperationException] {
- analyze("tempTable")
+ ctx.analyze("tempTable")
}
- catalog.unregisterTable(Seq("tempTable"))
+ ctx.catalog.unregisterTable(Seq("tempTable"))
}
test("estimates the size of a test MetastoreRelation") {
@@ -151,8 +152,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
val sizes = df.queryExecution.analyzed.collect {
case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes
}
- assert(sizes.size === 2 && sizes(0) <= conf.autoBroadcastJoinThreshold
- && sizes(1) <= conf.autoBroadcastJoinThreshold,
+ assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold
+ && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold,
s"query should contain two relations, each of which has size smaller than autoConvertSize")
// Using `sparkPlan` because for relevant patterns in HashJoin to be
@@ -163,8 +164,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(df, expectedAnswer) // check correctness of output
- TestHive.conf.settings.synchronized {
- val tmp = conf.autoBroadcastJoinThreshold
+ ctx.conf.settings.synchronized {
+ val tmp = ctx.conf.autoBroadcastJoinThreshold
sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""")
df = sql(query)
@@ -207,8 +208,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
.isAssignableFrom(r.getClass) =>
r.statistics.sizeInBytes
}
- assert(sizes.size === 2 && sizes(1) <= conf.autoBroadcastJoinThreshold
- && sizes(0) <= conf.autoBroadcastJoinThreshold,
+ assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold
+ && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold,
s"query should contain two relations, each of which has size smaller than autoConvertSize")
// Using `sparkPlan` because for relevant patterns in HashJoin to be
@@ -221,8 +222,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(df, answer) // check correctness of output
- TestHive.conf.settings.synchronized {
- val tmp = conf.autoBroadcastJoinThreshold
+ ctx.conf.settings.synchronized {
+ val tmp = ctx.conf.autoBroadcastJoinThreshold
sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
df = sql(leftSemiJoinQuery)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 8245047626d5..4056dee77757 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -17,20 +17,20 @@
package org.apache.spark.sql.hive
-/* Implicits */
-
import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.hive.test.TestHive._
case class FunctionResult(f1: String, f2: String)
class UDFSuite extends QueryTest {
+
+ private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
+
test("UDF case insensitive") {
- udf.register("random0", () => { Math.random() })
- udf.register("RANDOM1", () => { Math.random() })
- udf.register("strlenScala", (_: String).length + (_: Int))
- assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
- assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
- assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
+ ctx.udf.register("random0", () => { Math.random() })
+ ctx.udf.register("RANDOM1", () => { Math.random() })
+ ctx.udf.register("strlenScala", (_: String).length + (_: Int))
+ assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
+ assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
+ assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index 321dc8d7322b..9a571650b6e2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -17,18 +17,17 @@
package org.apache.spark.sql.hive.client
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.util.Utils
-import org.scalatest.FunSuite
/**
- * A simple set of tests that call the methods of a hive ClientInterface, loading different version
- * of hive from maven central. These tests are simple in that they are mostly just testing to make
- * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionallity
+ * A simple set of tests that call the methods of a hive ClientInterface, loading different versions
+ * of hive from maven central. These tests are simple in that they are mostly just testing to make
+ * sure that reflective calls are not throwing NoSuchMethod errors, but the actual functionality
* is not fully tested.
*/
-class VersionsSuite extends FunSuite with Logging {
+class VersionsSuite extends SparkFunSuite with Logging {
private def buildConf() = {
lazy val warehousePath = Utils.createTempDir()
lazy val metastorePath = Utils.createTempDir()
@@ -73,7 +72,7 @@ class VersionsSuite extends FunSuite with Logging {
assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'")
}
- private val versions = Seq("12", "13")
+ private val versions = Seq("12", "13", "14")
private var client: ClientInterface = null
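Adding "14" to `versions` means the version-parameterized tests in this suite now also exercise the Hive 0.14 client. As a hedged sketch (not the suite's real code) of how such a list can drive per-version test registration, the following registers one test per version string; `buildClient` is a hypothetical stand-in for however the client is actually constructed.

```scala
import org.scalatest.FunSuite

// Sketch only: one test is registered per version string at construction time,
// which is the general shape a version matrix suite follows.
abstract class PerVersionSuite extends FunSuite {
  /** Hypothetical factory; the real suite builds a ClientInterface reflectively. */
  protected def buildClient(hiveVersion: String): AnyRef

  private val versions = Seq("12", "13", "14")

  versions.foreach { version =>
    test(s"$version: client can be constructed") {
      assert(buildClient(version) != null)
    }
  }
}
```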
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala
index 23ece7e7cf6e..b0d3dd44daed 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.hive.execution
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.hive.test.TestHiveContext
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll {
+class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll {
ignore("multiple instances not supported") {
test("Multiple Hive Instances") {
(1 to 10).map { i =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 55e5551b6381..c9dd4c0935a7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql.hive.execution
import java.io._
-import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
+import org.scalatest.{BeforeAndAfterAll, GivenWhenThen}
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.sources.DescribeCommand
import org.apache.spark.sql.execution.{SetCommand, ExplainCommand}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -40,7 +40,7 @@ import org.apache.spark.sql.hive.test.TestHive
* configured using system properties.
*/
abstract class HiveComparisonTest
- extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging {
+ extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging {
/**
* When set, any cache files that result in test failures will be deleted. Used when the test
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 4af31d482ce4..6d8d99ebc816 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -57,7 +57,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
// https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
sql(
"""
- |CREATE TEMPORARY FUNCTION udtf_count2
+ |CREATE TEMPORARY FUNCTION udtf_count2
|AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
""".stripMargin)
}
@@ -874,15 +874,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
|WITH serdeproperties('s1'='9')
""".stripMargin)
}
- // Now only verify 0.12.0, and ignore other versions due to binary compatibility
- // current TestSerDe.jar is from 0.12.0
- if (HiveShim.version == "0.12.0") {
- sql(s"ADD JAR $testJar")
- sql(
- """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe'
- |WITH serdeproperties('s1'='9')
- """.stripMargin)
- }
sql("DROP TABLE alter1")
}
@@ -890,15 +881,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
// this is a test case from mapjoin_addjar.q
val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath
val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath
- if (HiveShim.version == "0.13.1") {
- sql(s"ADD JAR $testJar")
- sql(
- """CREATE TABLE t1(a string, b string)
- |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin)
- sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""")
- sql("select * from src join t1 on src.key = t1.a")
- sql("DROP TABLE t1")
- }
+ sql(s"ADD JAR $testJar")
+ sql(
+ """CREATE TABLE t1(a string, b string)
+ |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin)
+ sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""")
+ sql("select * from src join t1 on src.key = t1.a")
+ sql("DROP TABLE t1")
}
test("ADD FILE command") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 0ba4d1147821..2209fc2f30a3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -61,7 +61,7 @@ class HiveTableScanSuite extends HiveComparisonTest {
TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect()
TestHive.sql("drop table tb")
}
-
+
test("Spark-4077: timestamp query for null value") {
TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null")
TestHive.sql(
@@ -71,11 +71,11 @@ class HiveTableScanSuite extends HiveComparisonTest {
FIELDS TERMINATED BY ','
LINES TERMINATED BY '\n'
""".stripMargin)
- val location =
+ val location =
Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile()
-
+
TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null")
- assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect()
+ assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect()
=== Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null)))
TestHive.sql("DROP TABLE timestamp_query_null")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index 7f49eac49057..ce5985888f54 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -101,7 +101,7 @@ class HiveUdfSuite extends QueryTest {
sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg")
TestHive.reset()
}
-
+
test("SPARK-2693 udaf aggregates test") {
checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src").collect().toSeq)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 27863a60145d..984d97d27bf5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
-import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation}
+import org.apache.spark.sql.hive.{HiveQLDialect, MetastoreRelation}
import org.apache.spark.sql.parquet.ParquetRelation2
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.types._
@@ -330,38 +330,54 @@ class SQLQuerySuite extends QueryTest {
"serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE"
)
- if (HiveShim.version =="0.13.1") {
- val origUseParquetDataSource = conf.parquetUseDataSourceApi
- try {
- setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
- sql(
- """CREATE TABLE ctas5
- | STORED AS parquet AS
- | SELECT key, value
- | FROM src
- | ORDER BY key, value""".stripMargin).collect()
-
- checkExistence(sql("DESC EXTENDED ctas5"), true,
- "name:key", "type:string", "name:value", "ctas5",
- "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat",
- "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat",
- "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe",
- "MANAGED_TABLE"
- )
-
- val default = getConf("spark.sql.hive.convertMetastoreParquet", "true")
- // use the Hive SerDe for parquet tables
- sql("set spark.sql.hive.convertMetastoreParquet = false")
- checkAnswer(
- sql("SELECT key, value FROM ctas5 ORDER BY key, value"),
- sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq)
- sql(s"set spark.sql.hive.convertMetastoreParquet = $default")
- } finally {
- setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString)
- }
+ val origUseParquetDataSource = conf.parquetUseDataSourceApi
+ try {
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
+ sql(
+ """CREATE TABLE ctas5
+ | STORED AS parquet AS
+ | SELECT key, value
+ | FROM src
+ | ORDER BY key, value""".stripMargin).collect()
+
+ checkExistence(sql("DESC EXTENDED ctas5"), true,
+ "name:key", "type:string", "name:value", "ctas5",
+ "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat",
+ "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat",
+ "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe",
+ "MANAGED_TABLE"
+ )
+
+ val default = getConf("spark.sql.hive.convertMetastoreParquet", "true")
+ // use the Hive SerDe for parquet tables
+ sql("set spark.sql.hive.convertMetastoreParquet = false")
+ checkAnswer(
+ sql("SELECT key, value FROM ctas5 ORDER BY key, value"),
+ sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq)
+ sql(s"set spark.sql.hive.convertMetastoreParquet = $default")
+ } finally {
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString)
}
}
+ test("specifying the column list for CTAS") {
+ Seq((1, "111111"), (2, "222222")).toDF("key", "value").registerTempTable("mytable1")
+
+ sql("create table gen__tmp(a int, b string) as select key, value from mytable1")
+ checkAnswer(
+ sql("SELECT a, b from gen__tmp"),
+ sql("select key, value from mytable1").collect())
+ sql("DROP TABLE gen__tmp")
+
+ sql("create table gen__tmp(a double, b double) as select key, value from mytable1")
+ checkAnswer(
+ sql("SELECT a, b from gen__tmp"),
+ sql("select cast(key as double), cast(value as double) from mytable1").collect())
+ sql("DROP TABLE gen__tmp")
+
+ sql("drop table mytable1")
+ }
+
test("command substitution") {
sql("set tbl=src")
checkAnswer(
@@ -629,12 +645,20 @@ class SQLQuerySuite extends QueryTest {
.queryExecution.analyzed
}
- test("test script transform") {
+ test("test script transform for stdout") {
val data = (1 to 100000).map { i => (i, i, i) }
data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
assert(100000 ===
sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans")
- .queryExecution.toRdd.count())
+ .queryExecution.toRdd.count())
+ }
+
+ test("test script transform for stderr") {
+ val data = (1 to 100000).map { i => (i, i, i) }
+ data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
+ assert(0 ===
+ sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans")
+ .queryExecution.toRdd.count())
}
test("window function: udaf with aggregate expressin") {
@@ -780,6 +804,42 @@ class SQLQuerySuite extends QueryTest {
).map(i => Row(i._1, i._2, i._3, i._4)))
}
+ test("window function: multiple window expressions in a single expression") {
+ val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
+ nums.registerTempTable("nums")
+
+ val expected =
+ Row(1, 1, 1, 55, 1, 57) ::
+ Row(0, 2, 3, 55, 2, 60) ::
+ Row(1, 3, 6, 55, 4, 65) ::
+ Row(0, 4, 10, 55, 6, 71) ::
+ Row(1, 5, 15, 55, 9, 79) ::
+ Row(0, 6, 21, 55, 12, 88) ::
+ Row(1, 7, 28, 55, 16, 99) ::
+ Row(0, 8, 36, 55, 20, 111) ::
+ Row(1, 9, 45, 55, 25, 125) ::
+ Row(0, 10, 55, 55, 30, 140) :: Nil
+
+ val actual = sql(
+ """
+ |SELECT
+ | y,
+ | x,
+ | sum(x) OVER w1 AS running_sum,
+ | sum(x) OVER w2 AS total_sum,
+ | sum(x) OVER w3 AS running_sum_per_y,
+ | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2
+ |FROM nums
+ |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW),
+ | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING),
+ | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+ """.stripMargin)
+
+ checkAnswer(actual, expected)
+
+ dropTempTable("nums")
+ }
+
test("test case key when") {
(1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t")
checkAnswer(
@@ -815,6 +875,10 @@ class SQLQuerySuite extends QueryTest {
}
}
+ test("Cast STRING to BIGINT") {
+ checkAnswer(sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L))
+ }
+
// `Math.exp(1.0)` has different result for different jdk version, so not use createQueryTest
test("udf_java_method") {
checkAnswer(sql(
@@ -870,4 +934,32 @@ class SQLQuerySuite extends QueryTest {
sql("set hive.exec.dynamic.partition.mode=strict")
}
}
+
+ test("Call add jar in a different thread (SPARK-8306)") {
+ @volatile var error: Option[Throwable] = None
+ val thread = new Thread {
+ override def run() {
+ // To make sure this test works, this jar should not be loaded in another place.
+ TestHive.sql(
+ s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}")
+ try {
+ TestHive.sql(
+ """
+ |CREATE TEMPORARY FUNCTION example_max
+ |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax'
+ """.stripMargin)
+ } catch {
+ case throwable: Throwable =>
+ error = Some(throwable)
+ }
+ }
+ }
+ thread.start()
+ thread.join()
+ error match {
+ case Some(throwable) =>
+ fail("CREATE TEMPORARY FUNCTION should not fail.", throwable)
+ case None => // OK
+ }
+ }
}
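The SPARK-8306 test above surfaces a failure from a worker thread by stashing the `Throwable` in a `@volatile` field and re-checking it after `join()`. A self-contained sketch of that pattern outside ScalaTest follows; `ThreadErrorCapture` and its body are illustrative only.

```scala
// Generic form of the error-capture pattern used by the SPARK-8306 test:
// record the worker thread's Throwable and re-raise it on the calling thread.
object ThreadErrorCapture {
  def main(args: Array[String]): Unit = {
    @volatile var error: Option[Throwable] = None

    val worker = new Thread {
      override def run(): Unit = {
        try {
          require(args.nonEmpty, "no arguments supplied")  // stand-in for real work
        } catch {
          case t: Throwable => error = Some(t)
        }
      }
    }

    worker.start()
    worker.join()
    error.foreach(t => throw new AssertionError("worker thread failed", t))
  }
}
```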
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
index 88c99e35260d..8707f9f936be 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.hive.orc
import java.io.File
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.expressions.InternalRow
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.util.Utils
-import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
+import org.scalatest.BeforeAndAfterAll
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
@@ -38,7 +39,7 @@ case class OrcParData(intField: Int, stringField: String)
case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot
-class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
+class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll {
val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal
def withTempDir(f: File => Unit): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index cdd6e705f4a2..267d22c6b5f1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -21,10 +21,11 @@ import java.io.File
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.ql.io.orc.CompressionKind
-import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
+import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.expressions.InternalRow
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
@@ -50,10 +51,7 @@ case class Contact(name: String, phone: String)
case class Person(name: String, age: Int, contacts: Seq[Contact])
-class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll with OrcTest {
- override val sqlContext = TestHive
-
- import TestHive.read
+class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
def getTempFilePath(prefix: String, suffix: String = ""): File = {
val tempFile = File.createTempFile(prefix, suffix)
@@ -68,7 +66,7 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll w
withOrcFile(data) { file =>
checkAnswer(
- read.format("orc").load(file),
+ sqlContext.read.format("orc").load(file),
data.toDF().collect())
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
index 750f0b04aaa8..5daf691aa8c5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -22,13 +22,11 @@ import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql._
private[sql] trait OrcTest extends SQLTestUtils {
- protected def hiveContext = sqlContext.asInstanceOf[HiveContext]
+ lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive
import sqlContext.sparkContext
import sqlContext.implicits._
@@ -53,7 +51,7 @@ private[sql] trait OrcTest extends SQLTestUtils {
protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
- withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path)))
+ withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path)))
}
/**
@@ -65,7 +63,7 @@ private[sql] trait OrcTest extends SQLTestUtils {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withOrcDataFrame(data) { df =>
- hiveContext.registerDataFrameAsTable(df, tableName)
+ sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
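`OrcTest`'s helpers (`withOrcFile`, `withOrcDataFrame`, `withOrcTable`) are loan-pattern fixtures: allocate a resource, lend it to the test body, and clean up afterwards. Below is a minimal standalone sketch of that shape using only the JDK; the object and method names are illustrative, not the trait's API.

```scala
import java.io.File
import java.nio.file.Files

// Sketch of the loan pattern: the caller never manages the temporary location.
object LoanPatternSketch {
  def withTempPath(f: File => Unit): Unit = {
    val dir = Files.createTempDirectory("orc-test-").toFile
    val path = new File(dir, "data")
    try f(path) finally {
      // best-effort cleanup of anything the body created
      Option(dir.listFiles()).foreach(_.foreach(_.delete()))
      dir.delete()
    }
  }

  def main(args: Array[String]): Unit = {
    withTempPath { file =>
      println(s"a test body would write ORC output under $file")
    }
  }
}
```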
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index e62ac909cbd0..3864349cdbd8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -21,8 +21,6 @@ import java.io.File
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD}
import org.apache.spark.sql.hive.execution.HiveTableScan
import org.apache.spark.sql.hive.test.TestHive._
@@ -30,7 +28,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan}
import org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode}
+import org.apache.spark.sql.{DataFrame, QueryTest, Row, SQLConf, SaveMode}
import org.apache.spark.util.Utils
// The data where the partitioning key exists only in the directory structure.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index cf5ae88dc4be..8648a91cbb99 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -17,20 +17,23 @@
package org.apache.spark.sql.sources
+import java.io.File
+
+import com.google.common.io.Files
import org.apache.hadoop.fs.Path
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException
abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
- override val sqlContext: SQLContext = TestHive
+ override lazy val sqlContext: SQLContext = TestHive
- import sqlContext._
+ import sqlContext.sql
import sqlContext.implicits._
val dataSourceName = classOf[SimpleTextSource].getCanonicalName
@@ -41,19 +44,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
StructField("a", IntegerType, nullable = false),
StructField("b", StringType, nullable = false)))
- val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b")
+ lazy val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b")
- val partitionedTestDF1 = (for {
+ lazy val partitionedTestDF1 = (for {
i <- 1 to 3
p2 <- Seq("foo", "bar")
} yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2")
- val partitionedTestDF2 = (for {
+ lazy val partitionedTestDF2 = (for {
i <- 1 to 3
p2 <- Seq("foo", "bar")
} yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2")
- val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2)
+ lazy val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2)
def checkQueries(df: DataFrame): Unit = {
// Selects everything
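These DataFrame members become `lazy val` because they go through `sqlContext` (now itself an overridden `lazy val`); a strict `val` here would be evaluated while the suite is still being constructed, forcing TestHive too early. A tiny standalone illustration of the initialization-order issue being avoided (all names below are made up for the example):

```scala
// Why `lazy val` matters: a strict val in a base class runs during the base
// constructor, before the subclass has initialized the members it depends on.
abstract class Base {
  def name: String                           // supplied by the subclass
  val eagerGreeting = s"hello, $name"        // evaluated in Base's constructor
  lazy val lazyGreeting = s"hello, $name"    // evaluated on first access
}

class Impl extends Base {
  val name = "spark"                         // initialized after Base's constructor
}

object InitOrderDemo extends App {
  val impl = new Impl
  println(impl.eagerGreeting)   // "hello, null" -- name was not initialized yet
  println(impl.lazyGreeting)    // "hello, spark"
}
```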
@@ -101,7 +104,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath)
checkAnswer(
- read.format(dataSourceName)
+ sqlContext.read.format(dataSourceName)
.option("path", file.getCanonicalPath)
.option("dataSchema", dataSchema.json)
.load(),
@@ -115,7 +118,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath)
checkAnswer(
- read.format(dataSourceName)
+ sqlContext.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath).orderBy("a"),
testDF.unionAll(testDF).orderBy("a").collect())
@@ -149,7 +152,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.save(file.getCanonicalPath)
checkQueries(
- read.format(dataSourceName)
+ sqlContext.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath))
}
@@ -170,7 +173,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.save(file.getCanonicalPath)
checkAnswer(
- read.format(dataSourceName)
+ sqlContext.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath),
partitionedTestDF.collect())
@@ -192,7 +195,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.save(file.getCanonicalPath)
checkAnswer(
- read.format(dataSourceName)
+ sqlContext.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath),
partitionedTestDF.unionAll(partitionedTestDF).collect())
@@ -214,7 +217,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.save(file.getCanonicalPath)
checkAnswer(
- read.format(dataSourceName)
+ sqlContext.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath),
partitionedTestDF.collect())
@@ -250,7 +253,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.saveAsTable("t")
withTable("t") {
- checkAnswer(table("t"), testDF.collect())
+ checkAnswer(sqlContext.table("t"), testDF.collect())
}
}
@@ -259,7 +262,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t")
withTable("t") {
- checkAnswer(table("t"), testDF.unionAll(testDF).orderBy("a").collect())
+ checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect())
}
}
@@ -278,7 +281,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
withTempTable("t") {
testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t")
- assert(table("t").collect().isEmpty)
+ assert(sqlContext.table("t").collect().isEmpty)
}
}
@@ -289,7 +292,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.saveAsTable("t")
withTable("t") {
- checkQueries(table("t"))
+ checkQueries(sqlContext.table("t"))
}
}
@@ -309,7 +312,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.saveAsTable("t")
withTable("t") {
- checkAnswer(table("t"), partitionedTestDF.collect())
+ checkAnswer(sqlContext.table("t"), partitionedTestDF.collect())
}
}
@@ -329,7 +332,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.saveAsTable("t")
withTable("t") {
- checkAnswer(table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect())
+ checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect())
}
}
@@ -349,7 +352,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.saveAsTable("t")
withTable("t") {
- checkAnswer(table("t"), partitionedTestDF.collect())
+ checkAnswer(sqlContext.table("t"), partitionedTestDF.collect())
}
}
@@ -398,7 +401,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.partitionBy("p1", "p2")
.saveAsTable("t")
- assert(table("t").collect().isEmpty)
+ assert(sqlContext.table("t").collect().isEmpty)
}
}
@@ -410,7 +413,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.partitionBy("p1", "p2")
.save(file.getCanonicalPath)
- val df = read
+ val df = sqlContext.read
.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(s"${file.getCanonicalPath}/p1=*/p2=???")
@@ -450,10 +453,24 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
.saveAsTable("t")
withTempTable("t") {
- checkAnswer(table("t"), input.collect())
+ checkAnswer(sqlContext.table("t"), input.collect())
}
}
}
+
+ test("SPARK-7616: adjust column name order accordingly when saving partitioned table") {
+ val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
+
+ df.write
+ .format(dataSourceName)
+ .mode(SaveMode.Overwrite)
+ .partitionBy("c", "a")
+ .saveAsTable("t")
+
+ withTable("t") {
+ checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect())
+ }
+ }
}
class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
@@ -485,7 +502,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
}
}
-class CommitFailureTestRelationSuite extends FunSuite with SQLTestUtils {
+class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
import TestHive.implicits._
override val sqlContext = TestHive
@@ -535,20 +552,6 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
}
}
- test("SPARK-7616: adjust column name order accordingly when saving partitioned table") {
- val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
-
- df.write
- .format("parquet")
- .mode(SaveMode.Overwrite)
- .partitionBy("c", "a")
- .saveAsTable("t")
-
- withTable("t") {
- checkAnswer(table("t"), df.select('b, 'c, 'a).collect())
- }
- }
-
test("SPARK-7868: _temporary directories should be ignored") {
withTempPath { dir =>
val df = Seq("a", "b", "c").zipWithIndex.toDF()
@@ -564,4 +567,99 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect())
}
}
+
+ test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") {
+ withTempDir { dir =>
+ val path = dir.getCanonicalPath
+ val df = Seq(1 -> "a").toDF()
+
+ // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw
+ // since it's not a valid Parquet file.
+ val emptyFile = new File(path, "empty")
+ Files.createParentDirs(emptyFile)
+ Files.touch(emptyFile)
+
+ // This shouldn't throw anything.
+ df.write.format("parquet").mode(SaveMode.Ignore).save(path)
+
+      // This should only complain that the destination directory already exists, rather than that
+      // the file "empty" is not a Parquet file.
+ assert {
+ intercept[RuntimeException] {
+ df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path)
+ }.getMessage.contains("already exists")
+ }
+
+ // This shouldn't throw anything.
+ df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
+ checkAnswer(read.format("parquet").load(path), df)
+ }
+ }
+
+ test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") {
+ withTempPath { dir =>
+ intercept[AnalysisException] {
+        // Parquet doesn't allow field names with spaces. Here we intentionally trigger an
+        // exception from the `ParquetRelation2.prepareForWriteJob()` method to reproduce
+        // the bug. Please refer to SPARK-8079 for more details.
+ range(1, 10)
+ .withColumnRenamed("id", "a b")
+ .write
+ .format("parquet")
+ .save(dir.getCanonicalPath)
+ }
+ }
+ }
+
+  test("SPARK-7943: DF created by hiveContext can create table in a specific db by saveAsTable") {
+
+ val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
+    // use dbname.tablename to target a specific db
+ sqlContext.sql("""create database if not exists testdb7943""")
+ df.write
+ .format("parquet")
+ .mode(SaveMode.Overwrite)
+ .saveAsTable("testdb7943.tbl7943_1")
+
+ df.write
+ .format("parquet")
+ .mode(SaveMode.Overwrite)
+ .saveAsTable("tbl7943_2")
+
+ intercept[NoSuchDatabaseException] {
+ df.write
+ .format("parquet")
+ .mode(SaveMode.Overwrite)
+ .saveAsTable("testdb7943-2.tbl1")
+ }
+
+ sqlContext.sql("""use testdb7943""")
+
+ df.write
+ .format("parquet")
+ .mode(SaveMode.Overwrite)
+ .saveAsTable("tbl7943_3")
+ df.write
+ .format("parquet")
+ .mode(SaveMode.Overwrite)
+ .saveAsTable("default.tbl7943_4")
+
+ checkAnswer(
+ sqlContext.sql("show TABLES in testdb7943"),
+ Seq(Row("tbl7943_1", false), Row("tbl7943_3", false)))
+
+ val result = sqlContext.sql("show TABLES in default")
+ checkAnswer(
+ result.filter("tableName = 'tbl7943_2'"),
+ Row("tbl7943_2", false))
+
+ checkAnswer(
+ result.filter("tableName = 'tbl7943_4'"),
+ Row("tbl7943_4", false))
+
+ sqlContext.sql("""use default""")
+ sqlContext.sql("""drop table if exists tbl7943_2 """)
+ sqlContext.sql("""drop table if exists tbl7943_4 """)
+ sqlContext.sql("""drop database if exists testdb7943 CASCADE""")
+ }
}
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
deleted file mode 100644
index dbc5e029e204..000000000000
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ /dev/null
@@ -1,457 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import java.rmi.server.UID
-import java.util.{Properties, ArrayList => JArrayList}
-import java.io.{OutputStream, InputStream}
-
-import scala.collection.JavaConversions._
-import scala.language.implicitConversions
-import scala.reflect.ClassTag
-
-import com.esotericsoftware.kryo.Kryo
-import com.esotericsoftware.kryo.io.Input
-import com.esotericsoftware.kryo.io.Output
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.hive.common.StatsSetupConst
-import org.apache.hadoop.hive.common.`type`.HiveDecimal
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.ql.Context
-import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
-import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
-import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
-import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory
-import org.apache.hadoop.hive.serde.serdeConstants
-import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
-import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector}
-import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory}
-import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo}
-import org.apache.hadoop.io.{NullWritable, Writable}
-import org.apache.hadoop.mapred.InputFormat
-import org.apache.hadoop.{io => hadoopIo}
-
-import org.apache.spark.Logging
-import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String}
-import org.apache.spark.util.Utils._
-
-/**
- * This class provides the UDF creation and also the UDF instance serialization and
- * de-serialization cross process boundary.
- *
- * Detail discussion can be found at https://github.com/apache/spark/pull/3640
- *
- * @param functionClassName UDF class name
- */
-private[hive] case class HiveFunctionWrapper(var functionClassName: String)
- extends java.io.Externalizable {
-
- // for Serialization
- def this() = this(null)
-
- @transient
- def deserializeObjectByKryo[T: ClassTag](
- kryo: Kryo,
- in: InputStream,
- clazz: Class[_]): T = {
- val inp = new Input(in)
- val t: T = kryo.readObject(inp,clazz).asInstanceOf[T]
- inp.close()
- t
- }
-
- @transient
- def serializeObjectByKryo(
- kryo: Kryo,
- plan: Object,
- out: OutputStream ) {
- val output: Output = new Output(out)
- kryo.writeObject(output, plan)
- output.close()
- }
-
- def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
- deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz)
- .asInstanceOf[UDFType]
- }
-
- def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
- serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out)
- }
-
- private var instance: AnyRef = null
-
- def writeExternal(out: java.io.ObjectOutput) {
- // output the function name
- out.writeUTF(functionClassName)
-
- // Write a flag if instance is null or not
- out.writeBoolean(instance != null)
- if (instance != null) {
- // Some of the UDF are serializable, but some others are not
- // Hive Utilities can handle both cases
- val baos = new java.io.ByteArrayOutputStream()
- serializePlan(instance, baos)
- val functionInBytes = baos.toByteArray
-
- // output the function bytes
- out.writeInt(functionInBytes.length)
- out.write(functionInBytes, 0, functionInBytes.length)
- }
- }
-
- def readExternal(in: java.io.ObjectInput) {
- // read the function name
- functionClassName = in.readUTF()
-
- if (in.readBoolean()) {
- // if the instance is not null
- // read the function in bytes
- val functionInBytesLength = in.readInt()
- val functionInBytes = new Array[Byte](functionInBytesLength)
- in.read(functionInBytes, 0, functionInBytesLength)
-
- // deserialize the function object via Hive Utilities
- instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes),
- getContextOrSparkClassLoader.loadClass(functionClassName))
- }
- }
-
- def createFunction[UDFType <: AnyRef](): UDFType = {
- if (instance != null) {
- instance.asInstanceOf[UDFType]
- } else {
- val func = getContextOrSparkClassLoader
- .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
- if (!func.isInstanceOf[UDF]) {
- // We cache the function if it's no the Simple UDF,
- // as we always have to create new instance for Simple UDF
- instance = func
- }
- func
- }
- }
-}
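The deleted `writeExternal`/`readExternal` above funnel the UDF instance through Hive's runtime Kryo via `serializePlan`/`deserializePlan`. For reference, a standalone sketch of that Kryo round trip follows, using a freshly constructed `Kryo` on a plain `String` rather than Hive's `Utilities.runtimeSerializationKryo`; the object name is illustrative.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

// Sketch only: the same write/read cycle the removed helpers performed,
// applied to a plain String instead of a UDF instance.
object KryoRoundTrip {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()

    val baos = new ByteArrayOutputStream()
    val output = new Output(baos)
    kryo.writeObject(output, "serialized udf state")
    output.close()

    val input = new Input(new ByteArrayInputStream(baos.toByteArray))
    val restored = kryo.readObject(input, classOf[String])
    input.close()

    assert(restored == "serialized udf state")
  }
}
```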
-
-/**
- * A compatibility layer for interacting with Hive version 0.13.1.
- */
-private[hive] object HiveShim {
- val version = "0.13.1"
-
- def getTableDesc(
- serdeClass: Class[_ <: Deserializer],
- inputFormatClass: Class[_ <: InputFormat[_, _]],
- outputFormatClass: Class[_],
- properties: Properties) = {
- new TableDesc(inputFormatClass, outputFormatClass, properties)
- }
-
-
- def getStringWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.stringTypeInfo, getStringWritable(value))
-
- def getIntWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.intTypeInfo, getIntWritable(value))
-
- def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value))
-
- def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value))
-
- def getLongWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.longTypeInfo, getLongWritable(value))
-
- def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.floatTypeInfo, getFloatWritable(value))
-
- def getShortWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.shortTypeInfo, getShortWritable(value))
-
- def getByteWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.byteTypeInfo, getByteWritable(value))
-
- def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value))
-
- def getDateWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.dateTypeInfo, getDateWritable(value))
-
- def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value))
-
- def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value))
-
- def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector =
- PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.voidTypeInfo, null)
-
- def getStringWritable(value: Any): hadoopIo.Text =
- if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString)
-
- def getIntWritable(value: Any): hadoopIo.IntWritable =
- if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])
-
- def getDoubleWritable(value: Any): hiveIo.DoubleWritable =
- if (value == null) {
- null
- } else {
- new hiveIo.DoubleWritable(value.asInstanceOf[Double])
- }
-
- def getBooleanWritable(value: Any): hadoopIo.BooleanWritable =
- if (value == null) {
- null
- } else {
- new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])
- }
-
- def getLongWritable(value: Any): hadoopIo.LongWritable =
- if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])
-
- def getFloatWritable(value: Any): hadoopIo.FloatWritable =
- if (value == null) {
- null
- } else {
- new hadoopIo.FloatWritable(value.asInstanceOf[Float])
- }
-
- def getShortWritable(value: Any): hiveIo.ShortWritable =
- if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])
-
- def getByteWritable(value: Any): hiveIo.ByteWritable =
- if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])
-
- def getBinaryWritable(value: Any): hadoopIo.BytesWritable =
- if (value == null) {
- null
- } else {
- new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])
- }
-
- def getDateWritable(value: Any): hiveIo.DateWritable =
- if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int])
-
- def getTimestampWritable(value: Any): hiveIo.TimestampWritable =
- if (value == null) {
- null
- } else {
- new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp])
- }
-
- def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable =
- if (value == null) {
- null
- } else {
- // TODO precise, scale?
- new hiveIo.HiveDecimalWritable(
- HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal))
- }
-
- def getPrimitiveNullWritable: NullWritable = NullWritable.get()
-
- def createDriverResultsArray = new JArrayList[Object]
-
- def processResults(results: JArrayList[Object]) = {
- results.map { r =>
- r match {
- case s: String => s
- case a: Array[Object] => a(0).asInstanceOf[String]
- }
- }
- }
-
- def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE
-
- def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE
-
- def createDefaultDBIfNeeded(context: HiveContext) = {
- context.runSqlHive("CREATE DATABASE default")
- context.runSqlHive("USE default")
- }
-
- def getCommandProcessor(cmd: Array[String], conf: HiveConf) = {
- CommandProcessorFactory.get(cmd, conf)
- }
-
- def createDecimal(bd: java.math.BigDecimal): HiveDecimal = {
- HiveDecimal.create(bd)
- }
-
- /*
- * This function in hive-0.13 become private, but we have to do this to walkaround hive bug
- */
- private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) {
- val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "")
- val result: StringBuilder = new StringBuilder(old)
- var first: Boolean = old.isEmpty
-
- for (col <- cols) {
- if (first) {
- first = false
- } else {
- result.append(',')
- }
- result.append(col)
- }
- conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString)
- }
-
- /*
- * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty
- */
- def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) {
- if (ids != null && ids.size > 0) {
- ColumnProjectionUtils.appendReadColumns(conf, ids)
- }
- if (names != null && names.size > 0) {
- appendReadColumnNames(conf, names)
- }
- }
-
- def getExternalTmpPath(context: Context, path: Path) = {
- context.getExternalTmpPath(path.toUri)
- }
-
- def getDataLocationPath(p: Partition) = p.getDataLocation
-
- def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsOf(tbl)
-
- def compatibilityBlackList = Seq()
-
- def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = {
- tbl.setDataLocation(new Path(crtTbl.getLocation()))
- }
-
- /*
- * Bug introdiced in hive-0.13. FileSinkDesc is serializable, but its member path is not.
- * Fix it through wrapper.
- * */
- implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = {
- var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed)
- f.setCompressCodec(w.compressCodec)
- f.setCompressType(w.compressType)
- f.setTableInfo(w.tableInfo)
- f.setDestTableId(w.destTableId)
- f
- }
-
- // Precision and scale to pass for unlimited decimals; these are the same as the precision and
- // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
- private val UNLIMITED_DECIMAL_PRECISION = 38
- private val UNLIMITED_DECIMAL_SCALE = 18
-
- def decimalMetastoreString(decimalType: DecimalType): String = decimalType match {
- case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)"
- case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)"
- }
-
- def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match {
- case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale)
- case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE)
- }
-
- def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = {
- val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo]
- DecimalType(info.precision(), info.scale())
- }
-
- def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
- if (hdoi.preferWritable()) {
- Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue,
- hdoi.precision(), hdoi.scale())
- } else {
- Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
- }
- }
-
- def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = {
- ObjectInspectorConverters.getConvertedOI(inputOI, outputOI)
- }
-
- /*
- * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that
- * is needed to initialize before serialization.
- */
- def prepareWritable(w: Writable): Writable = {
- w match {
- case w: AvroGenericRecordWritable =>
- w.setRecordReaderID(new UID())
- case _ =>
- }
- w
- }
-
- def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = {
- if (crtTbl != null && crtTbl.getNullFormat() != null) {
- tbl.setSerdeParam(serdeConstants.SERIALIZATION_NULL_FORMAT, crtTbl.getNullFormat())
- }
- }
-}
-
-/*
- * Bug introduced in hive-0.13. FileSinkDesc is serilizable, but its member path is not.
- * Fix it through wrapper.
- */
-private[hive] class ShimFileSinkDesc(
- var dir: String,
- var tableInfo: TableDesc,
- var compressed: Boolean)
- extends Serializable with Logging {
- var compressCodec: String = _
- var compressType: String = _
- var destTableId: Int = _
-
- def setCompressed(compressed: Boolean) {
- this.compressed = compressed
- }
-
- def getDirName = dir
-
- def setDestTableId(destTableId: Int) {
- this.destTableId = destTableId
- }
-
- def setTableInfo(tableInfo: TableDesc) {
- this.tableInfo = tableInfo
- }
-
- def setCompressCodec(intermediateCompressorCodec: String) {
- compressCodec = intermediateCompressorCodec
- }
-
- def setCompressType(intermediateCompressType: String) {
- compressType = intermediateCompressType
- }
-}
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 5ab7f4472c38..697895e72fe5 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
@@ -40,6 +40,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css
index b22c884bfebd..ec12616b58d8 100644
--- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css
+++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css
@@ -31,7 +31,7 @@
}
.tooltip-inner {
- max-width: 500px !important; // Make sure we only have one line tooltip
+ max-width: 500px !important; /* Make sure we only have one line tooltip */
}
.line {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 25842d502543..9cd9684d3640 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import scala.collection.Map
import scala.collection.mutable.Queue
import scala.reflect.ClassTag
+import scala.util.control.NonFatal
import akka.actor.{Props, SupervisorStrategy}
import org.apache.hadoop.conf.Configuration
@@ -270,6 +271,8 @@ class StreamingContext private[streaming] (
* Create an input stream with any arbitrary user implemented receiver.
* Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html
* @param receiver Custom implementation of Receiver
+ *
+ * @deprecated As of 1.0.0, replaced by `receiverStream`.
*/
@deprecated("Use receiverStream", "1.0.0")
def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = {
@@ -576,18 +579,26 @@ class StreamingContext private[streaming] (
def start(): Unit = synchronized {
state match {
case INITIALIZED =>
- validate()
startSite.set(DStream.getCreationSite())
sparkContext.setCallSite(startSite.get)
StreamingContext.ACTIVATION_LOCK.synchronized {
StreamingContext.assertNoOtherContextIsActive()
- scheduler.start()
- uiTab.foreach(_.attach())
- state = StreamingContextState.ACTIVE
+ try {
+ validate()
+ scheduler.start()
+ state = StreamingContextState.ACTIVE
+ } catch {
+ case NonFatal(e) =>
+ logError("Error starting the context, marking it as stopped", e)
+ scheduler.stop(false)
+ state = StreamingContextState.STOPPED
+ throw e
+ }
StreamingContext.setActiveContext(this)
}
shutdownHookRef = Utils.addShutdownHook(
StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown)
+ uiTab.foreach(_.attach())
logInfo("StreamingContext started")
case ACTIVE =>
logWarning("StreamingContext has already been started")
@@ -608,6 +619,8 @@ class StreamingContext private[streaming] (
* Wait for the execution to stop. Any exceptions that occurs during the execution
* will be thrown in this thread.
* @param timeout time to wait in milliseconds
+ *
+ * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`.
*/
@deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0")
def awaitTermination(timeout: Long) {
@@ -732,6 +745,10 @@ object StreamingContext extends Logging {
}
}
+ /**
+ * @deprecated As of 1.3.0, replaced by implicit functions in the DStream companion object.
+ * This is kept here only for backward compatibility.
+ */
@deprecated("Replaced by implicit functions in the DStream companion object. This is " +
"kept here only for backward compatibility.", "1.3.0")
def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
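A minimal sketch (not part of this patch) of the new start() failure contract earlier in this file: if validate() or scheduler.start() throws, the context now stops its scheduler and marks itself STOPPED before rethrowing, so callers see a clean state. The setup here is illustrative only; the new StreamingContextSuite test further below exercises the same behaviour.

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext, StreamingContextState}

    val conf = new SparkConf().setMaster("local[2]").setAppName("start-failure-sketch")
    val ssc = new StreamingContext(conf, Seconds(1))
    // With no output operations registered (or a missing checkpoint directory for a
    // stateful stream), validate() throws when start() is called.
    try {
      ssc.start()
    } catch {
      case e: Exception =>
        // start() now marks the context STOPPED before rethrowing, instead of
        // leaving it half-started.
        assert(ssc.getState() == StreamingContextState.STOPPED)
    }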
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index b639b94d5ca4..989e3a729ebc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -148,6 +148,9 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable {
/** The underlying SparkContext */
val sparkContext = new JavaSparkContext(ssc.sc)
+ /**
+ * @deprecated As of 0.9.0, replaced by `sparkContext`.
+ */
@deprecated("use sparkContext", "0.9.0")
val sc: JavaSparkContext = sparkContext
@@ -619,6 +622,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable {
* Wait for the execution to stop. Any exceptions that occurs during the execution
* will be thrown in this thread.
* @param timeout time to wait in milliseconds
+ * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`.
*/
@deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0")
def awaitTermination(timeout: Long): Unit = {
@@ -677,6 +681,7 @@ object JavaStreamingContext {
*
* @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
* @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory.
*/
@deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0")
def getOrCreate(
@@ -699,6 +704,7 @@ object JavaStreamingContext {
* @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
* @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
* file system
+ * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory.
*/
@deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0")
def getOrCreate(
@@ -724,6 +730,7 @@ object JavaStreamingContext {
* file system
* @param createOnError Whether to create a new JavaStreamingContext if there is an
* error in reading checkpoint data.
+ * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory.
*/
@deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0")
def getOrCreate(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 6efcc193bfcc..192aa6a139bc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -603,6 +603,8 @@ abstract class DStream[T: ClassTag] (
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
* 'this' DStream will be registered as an output stream and therefore materialized.
+ *
+ * @deprecated As of 0.9.0, replaced by `foreachRDD`.
*/
@deprecated("use foreachRDD", "0.9.0")
def foreach(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope {
@@ -612,6 +614,8 @@ abstract class DStream[T: ClassTag] (
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
* 'this' DStream will be registered as an output stream and therefore materialized.
+ *
+ * @deprecated As of 0.9.0, replaced by `foreachRDD`.
*/
@deprecated("use foreachRDD", "0.9.0")
def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index e4ff05e12f20..e76e7eb0dea1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -70,7 +70,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray
// Register the input blocks information into InputInfoTracker
- val inputInfo = InputInfo(id, blockInfos.map(_.numRecords).sum)
+ val inputInfo = InputInfo(id, blockInfos.flatMap(_.numRecords).sum)
ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
if (blockInfos.nonEmpty) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
index 0588517a2de3..92b51ce39234 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.storage.StreamBlockId
import org.apache.spark.streaming.util.RecurringTimer
-import org.apache.spark.util.{SystemClock, Utils}
+import org.apache.spark.util.SystemClock
/** Listener object for BlockGenerator events */
private[streaming] trait BlockGeneratorListener {
@@ -80,6 +80,8 @@ private[streaming] class BlockGenerator(
private val clock = new SystemClock()
private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms")
+ require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value")
+
private val blockIntervalTimer =
new RecurringTimer(clock, blockIntervalMs, updateCurrentBuffer, "BlockGenerator")
private val blockQueueSize = conf.getInt("spark.streaming.blockQueueSize", 10)
@@ -191,7 +193,7 @@ private[streaming] class BlockGenerator(
logError(message, t)
listener.onError(message, t)
}
-
+
private def pushBlock(block: Block) {
listener.onPushBlock(block.id, block.buffer)
logInfo("Pushed block " + block.id)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
index 651b534ac190..207d64d9414e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
@@ -62,7 +62,7 @@ private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockI
private[streaming] class BlockManagerBasedBlockHandler(
blockManager: BlockManager, storageLevel: StorageLevel)
extends ReceivedBlockHandler with Logging {
-
+
def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = {
val putResult: Seq[(BlockId, BlockStatus)] = block match {
case ArrayBufferBlock(arrayBuffer) =>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index 92938379b9c1..8be732b64e3a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -138,8 +138,8 @@ private[streaming] class ReceiverSupervisorImpl(
) {
val blockId = blockIdOption.getOrElse(nextBlockId)
val numRecords = receivedBlock match {
- case ArrayBufferBlock(arrayBuffer) => arrayBuffer.size
- case _ => -1
+ case ArrayBufferBlock(arrayBuffer) => Some(arrayBuffer.size.toLong)
+ case _ => None
}
val time = System.currentTimeMillis
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala
index a72efccf2f99..7c0db8a863c6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala
@@ -23,7 +23,9 @@ import org.apache.spark.Logging
import org.apache.spark.streaming.{Time, StreamingContext}
/** To track the information of input stream at specified batch time. */
-private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long)
+private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) {
+ require(numRecords >= 0, "numRecords must not be negative")
+}
/**
* This class manages all the input streams as well as their input data statistics. The information
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 1d1ddaaccf21..4af9b6d3b56a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -126,6 +126,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
eventLoop.post(ErrorReported(msg, e))
}
+ def isStarted(): Boolean = synchronized {
+ eventLoop != null
+ }
+
private def processEvent(event: JobSchedulerEvent) {
try {
event match {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala
index dc11e84f2996..656ac80df897 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala
@@ -24,11 +24,13 @@ import org.apache.spark.streaming.util.WriteAheadLogRecordHandle
/** Information about blocks received by the receiver */
private[streaming] case class ReceivedBlockInfo(
streamId: Int,
- numRecords: Long,
+ numRecords: Option[Long],
metadataOption: Option[Any],
blockStoreResult: ReceivedBlockStoreResult
) {
+ require(numRecords.isEmpty || numRecords.get >= 0, "numRecords must not be negative")
+
@volatile private var _isBlockIdValid = true
def blockId: StreamBlockId = blockStoreResult.blockId
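A tiny sketch with hypothetical values of how these Option[Long] record counts are meant to be aggregated: blocks whose size is unknown simply drop out, which is the same flatMap(_.numRecords).sum pattern used in ReceiverInputDStream above.

    // None models a block whose receiver could not count its records.
    val numRecordsPerBlock: Seq[Option[Long]] = Seq(Some(10L), None, Some(5L))
    val knownRecords: Long = numRecordsPerBlock.flatten.sum   // 15; None contributes nothing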
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
index 6a1dd6949b20..9b5e4dc819a2 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.streaming
import java.io.NotSerializableException
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.{HashPartitioner, SparkContext, SparkException}
+import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.ReturnStatementInClosureException
@@ -29,7 +29,7 @@ import org.apache.spark.util.ReturnStatementInClosureException
/**
* Test that closures passed to DStream operations are actually cleaned.
*/
-class DStreamClosureSuite extends FunSuite with BeforeAndAfterAll {
+class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
private var ssc: StreamingContext = null
override def beforeAll(): Unit = {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
index e3fb2ef13085..8844c9d74b93 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.streaming
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDDOperationScope
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.ui.UIUtils
@@ -27,7 +27,7 @@ import org.apache.spark.streaming.ui.UIUtils
/**
* Tests whether scope information is passed from DStream operations to RDDs correctly.
*/
-class DStreamScopeSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll {
+class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
private var ssc: StreamingContext = null
private val batchDuration: Duration = Seconds(1)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 23804237bda8..cca8cedb1d08 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -25,7 +25,7 @@ import scala.concurrent.duration._
import scala.language.postfixOps
import org.apache.hadoop.conf.Configuration
-import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
import org.apache.spark._
@@ -41,7 +41,11 @@ import org.apache.spark.util.{ManualClock, Utils}
import WriteAheadLogBasedBlockHandler._
import WriteAheadLogSuite._
-class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging {
+class ReceivedBlockHandlerSuite
+ extends SparkFunSuite
+ with BeforeAndAfter
+ with Matchers
+ with Logging {
val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1")
val hadoopConf = new Configuration()
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index b1af8d5eaacf..be305b5e0dfe 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -25,10 +25,10 @@ import scala.language.{implicitConversions, postfixOps}
import scala.util.Random
import org.apache.hadoop.conf.Configuration
-import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite}
import org.apache.spark.storage.StreamBlockId
import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
import org.apache.spark.streaming.scheduler._
@@ -37,7 +37,7 @@ import org.apache.spark.streaming.util.WriteAheadLogSuite._
import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils}
class ReceivedBlockTrackerSuite
- extends FunSuite with BeforeAndAfter with Matchers with Logging {
+ extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
val hadoopConf = new Configuration()
val akkaTimeout = 10 seconds
@@ -224,7 +224,7 @@ class ReceivedBlockTrackerSuite
/** Generate blocks infos using random ids */
def generateBlockInfos(): Seq[ReceivedBlockInfo] = {
- List.fill(5)(ReceivedBlockInfo(streamId, 0, None,
+ List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None,
BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)))))
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index e36c7914b130..819dd2ccfe91 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -25,16 +25,16 @@ import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts
import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.SpanSugar._
-import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}
+import org.scalatest.{Assertions, BeforeAndAfter}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.util.Utils
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, SparkFunSuite}
-class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging {
+class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging {
val master = "local[2]"
val appName = this.getClass.getSimpleName
@@ -151,6 +151,22 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
assert(StreamingContext.getActive().isEmpty)
}
+ test("start failure should stop internal components") {
+ ssc = new StreamingContext(conf, batchDuration)
+ val inputStream = addInputStream(ssc)
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ Some(values.sum + state.getOrElse(0))
+ }
+ inputStream.map(x => (x, 1)).updateStateByKey[Int](updateFunc)
+ // Require that the start fails because checkpoint directory was not set
+ intercept[Exception] {
+ ssc.start()
+ }
+ assert(ssc.getState() === StreamingContextState.STOPPED)
+ assert(ssc.scheduler.isStarted === false)
+ }
+
+
test("start multiple times") {
ssc = new StreamingContext(master, appName, batchDuration)
addInputStream(ssc).register()
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 554cd30223f4..31b1aebf6a8e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -24,12 +24,12 @@ import scala.collection.mutable.SynchronizedBuffer
import scala.language.implicitConversions
import scala.reflect.ClassTag
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
import org.scalatest.time.{Span, Seconds => ScalaTestSeconds}
import org.scalatest.concurrent.Eventually.timeout
import org.scalatest.concurrent.PatienceConfiguration
-import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
import org.apache.spark.streaming.scheduler._
@@ -204,7 +204,7 @@ class BatchCounter(ssc: StreamingContext) {
* This is the base trait for Spark Streaming testsuites. This provides basic functionality
* to run user-defined set of input on user-defined stream operations, and verify the output.
*/
-trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
+trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
// Name of the framework for Spark context
def framework: String = this.getClass.getSimpleName
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
index 441bbf95d015..a08578680cff 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
@@ -27,20 +27,20 @@ import org.scalatest.selenium.WebBrowser
import org.scalatest.time.SpanSugar._
import org.apache.spark._
-
-
-
+import org.apache.spark.ui.SparkUICssErrorHandler
/**
- * Selenium tests for the Spark Web UI.
+ * Selenium tests for the Spark Streaming Web UI.
*/
class UISeleniumSuite
- extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase {
+ extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase {
implicit var webDriver: WebDriver = _
override def beforeAll(): Unit = {
- webDriver = new HtmlUnitDriver
+ webDriver = new HtmlUnitDriver {
+ getWebClient.setCssErrorHandler(new SparkUICssErrorHandler)
+ }
}
override def afterAll(): Unit = {
@@ -197,4 +197,3 @@ class UISeleniumSuite
}
}
}
-
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index 6859b65c7165..cb017b798b2a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -21,15 +21,15 @@ import java.io.File
import scala.util.Random
import org.apache.hadoop.conf.Configuration
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter}
import org.apache.spark.util.Utils
-import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
class WriteAheadLogBackedBlockRDDSuite
- extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
+ extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
val conf = new SparkConf()
.setMaster("local[2]")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
index 5478b4184594..2e210397fe7c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
@@ -17,12 +17,12 @@
package org.apache.spark.streaming.scheduler
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.streaming.{Time, Duration, StreamingContext}
-class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter {
+class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter {
private var ssc: StreamingContext = _
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala
index e9ab917ab845..d3ca2b58f36c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.streaming.ui
import java.util.TimeZone
import java.util.concurrent.TimeUnit
-import org.scalatest.FunSuite
import org.scalatest.Matchers
-class UIUtilsSuite extends FunSuite with Matchers{
+import org.apache.spark.SparkFunSuite
+
+class UIUtilsSuite extends SparkFunSuite with Matchers {
test("shortTimeUnitString") {
assert("ns" === UIUtils.shortTimeUnitString(TimeUnit.NANOSECONDS))
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
index 9ebf7b484f42..78fc344b0017 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.streaming.util
import java.io.ByteArrayOutputStream
import java.util.concurrent.TimeUnit._
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class RateLimitedOutputStreamSuite extends FunSuite {
+class RateLimitedOutputStreamSuite extends SparkFunSuite {
private def benchmark[U](f: => U): Long = {
val start = System.nanoTime
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 79098bcf4861..325ff7c74c39 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -28,15 +28,15 @@ import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.scalatest.concurrent.Eventually._
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
import org.apache.spark.util.{ManualClock, Utils}
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
-class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
+class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter {
import WriteAheadLogSuite._
-
+
val hadoopConf = new Configuration()
var tempDir: File = null
var testDir: String = null
@@ -359,7 +359,7 @@ object WriteAheadLogSuite {
): FileBasedWriteAheadLog = {
if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000)
val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1)
-
+
// Ensure that 500 does not get sorted after 2000, so put a high base value.
data.foreach { item =>
manualClock.advance(500)
diff --git a/tools/pom.xml b/tools/pom.xml
index 1c6f3e83a181..feffde4c857e 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -20,7 +20,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 2fd17267ac42..62c6354f1e20 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -22,7 +22,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
index 24b289209805..192c6714b240 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
@@ -25,8 +25,7 @@ public final class PlatformDependent {
/**
* Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of
- * this package. This also lets us aovid accidental use of deprecated methods or methods that
- * aren't present in Java 6.
+ * this package. This also lets us avoid accidental use of deprecated methods.
*/
public static final class UNSAFE {
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
new file mode 100644
index 000000000000..a35168019549
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.types;
+
+import java.io.Serializable;
+import java.io.UnsupportedEncodingException;
+import java.util.Arrays;
+import javax.annotation.Nullable;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * A UTF-8 String for internal Spark use.
+ *
+ * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison and
+ * search; see http://en.wikipedia.org/wiki/UTF-8 for details.
+ *
+ * Note: This is not designed for general use cases and should not be used outside SQL.
+ */
+public final class UTF8String implements Comparable<UTF8String>, Serializable {
+
+ @Nullable
+ private byte[] bytes;
+
+ private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 5, 5, 5, 5,
+ 6, 6, 6, 6};
+
+ public static UTF8String fromBytes(byte[] bytes) {
+ return (bytes != null) ? new UTF8String().set(bytes) : null;
+ }
+
+ public static UTF8String fromString(String str) {
+ return (str != null) ? new UTF8String().set(str) : null;
+ }
+
+ /**
+ * Updates the UTF8String with String.
+ */
+ public UTF8String set(final String str) {
+ try {
+ bytes = str.getBytes("utf-8");
+ } catch (UnsupportedEncodingException e) {
+ // Turn the exception into unchecked so we can find out about it at runtime, but
+ // don't need to add lots of boilerplate code everywhere.
+ PlatformDependent.throwException(e);
+ }
+ return this;
+ }
+
+ /**
+ * Updates the UTF8String with byte[], which should be encoded in UTF-8.
+ */
+ public UTF8String set(final byte[] bytes) {
+ this.bytes = bytes;
+ return this;
+ }
+
+ /**
+ * Returns the number of bytes for a code point with the first byte as `b`
+ * @param b The first byte of a code point
+ */
+ public int numBytes(final byte b) {
+ final int offset = (b & 0xFF) - 192;
+ return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1;
+ }
+
+ /**
+ * Returns the number of code points in it.
+ *
+ * This is only used by Substring() when `start` is negative.
+ */
+ public int length() {
+ int len = 0;
+ for (int i = 0; i < bytes.length; i+= numBytes(bytes[i])) {
+ len += 1;
+ }
+ return len;
+ }
+
+ public byte[] getBytes() {
+ return bytes;
+ }
+
+ /**
+ * Returns a substring of this.
+ * @param start the position of first code point
+ * @param until the position after last code point, exclusive.
+ */
+ public UTF8String substring(final int start, final int until) {
+ if (until <= start || start >= bytes.length) {
+ return UTF8String.fromBytes(new byte[0]);
+ }
+
+ int i = 0;
+ int c = 0;
+ for (; i < bytes.length && c < start; i += numBytes(bytes[i])) {
+ c += 1;
+ }
+
+ int j = i;
+ for (; j < bytes.length && c < until; j += numBytes(bytes[j])) {
+ c += 1;
+ }
+
+ return UTF8String.fromBytes(Arrays.copyOfRange(bytes, i, j));
+ }
+
+ public boolean contains(final UTF8String substring) {
+ final byte[] b = substring.getBytes();
+ if (b.length == 0) {
+ return true;
+ }
+
+ for (int i = 0; i <= bytes.length - b.length; i++) {
+ // TODO: Avoid copying.
+ if (bytes[i] == b[0] && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ public boolean startsWith(final UTF8String prefix) {
+ final byte[] b = prefix.getBytes();
+ // TODO: Avoid copying.
+ return b.length <= bytes.length && Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b);
+ }
+
+ public boolean endsWith(final UTF8String suffix) {
+ final byte[] b = suffix.getBytes();
+ return b.length <= bytes.length &&
+ Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b);
+ }
+
+ public UTF8String toUpperCase() {
+ return UTF8String.fromString(toString().toUpperCase());
+ }
+
+ public UTF8String toLowerCase() {
+ return UTF8String.fromString(toString().toLowerCase());
+ }
+
+ @Override
+ public String toString() {
+ try {
+ return new String(bytes, "utf-8");
+ } catch (UnsupportedEncodingException e) {
+ // Turn the exception into unchecked so we can find out about it at runtime, but
+ // don't need to add lots of boilerplate code everywhere.
+ PlatformDependent.throwException(e);
+ return "unknown"; // we will never reach here.
+ }
+ }
+
+ @Override
+ public UTF8String clone() {
+ return new UTF8String().set(bytes);
+ }
+
+ @Override
+ public int compareTo(final UTF8String other) {
+ final byte[] b = other.getBytes();
+ for (int i = 0; i < bytes.length && i < b.length; i++) {
+ int res = bytes[i] - b[i];
+ if (res != 0) {
+ return res;
+ }
+ }
+ return bytes.length - b.length;
+ }
+
+ public int compare(final UTF8String other) {
+ return compareTo(other);
+ }
+
+ @Override
+ public boolean equals(final Object other) {
+ if (other instanceof UTF8String) {
+ return Arrays.equals(bytes, ((UTF8String) other).getBytes());
+ } else if (other instanceof String) {
+ // Used only in unit tests.
+ String s = (String) other;
+ return bytes.length >= s.length() && length() == s.length() && toString().equals(s);
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(bytes);
+ }
+}
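A short, hypothetical usage sketch for the class added above (Scala syntax; it mirrors the unit tests added further below). Lengths and substring positions are counted in code points, not bytes.

    import org.apache.spark.unsafe.types.UTF8String

    val s = UTF8String.fromString("数据砖头")           // 4 code points, 12 bytes in UTF-8
    s.length()                                          // 4
    s.substring(1, 3)                                   // UTF8String for "据砖"
    s.contains(UTF8String.fromString("砖头"))           // true
    s.startsWith(UTF8String.fromString("数据"))         // true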
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java
index 18393db9f382..a93fc0ee297c 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java
@@ -18,7 +18,6 @@
package org.apache.spark.unsafe.bitset;
import junit.framework.Assert;
-import org.apache.spark.unsafe.bitset.BitSet;
import org.junit.Test;
import org.apache.spark.unsafe.memory.MemoryBlock;
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
new file mode 100644
index 000000000000..80c179a1b5e7
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -0,0 +1,93 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.unsafe.types;
+
+import java.io.UnsupportedEncodingException;
+
+import junit.framework.Assert;
+import org.junit.Test;
+
+public class UTF8StringSuite {
+
+ private void checkBasic(String str, int len) throws UnsupportedEncodingException {
+ Assert.assertEquals(UTF8String.fromString(str).length(), len);
+ Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).length(), len);
+
+ Assert.assertEquals(UTF8String.fromString(str), str);
+ Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), str);
+ Assert.assertEquals(UTF8String.fromString(str).toString(), str);
+ Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).toString(), str);
+ Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), UTF8String.fromString(str));
+
+ Assert.assertEquals(UTF8String.fromString(str).hashCode(),
+ UTF8String.fromBytes(str.getBytes("utf8")).hashCode());
+ }
+
+ @Test
+ public void basicTest() throws UnsupportedEncodingException {
+ checkBasic("hello", 5);
+ checkBasic("世 界", 3);
+ }
+
+ @Test
+ public void contains() {
+ Assert.assertTrue(UTF8String.fromString("hello").contains(UTF8String.fromString("ello")));
+ Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("vello")));
+ Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("hellooo")));
+ Assert.assertTrue(UTF8String.fromString("大千世界").contains(UTF8String.fromString("千世")));
+ Assert.assertFalse(UTF8String.fromString("大千世界").contains(UTF8String.fromString("世千")));
+ Assert.assertFalse(
+ UTF8String.fromString("大千世界").contains(UTF8String.fromString("大千世界好")));
+ }
+
+ @Test
+ public void startsWith() {
+ Assert.assertTrue(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hell")));
+ Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("ell")));
+ Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hellooo")));
+ Assert.assertTrue(UTF8String.fromString("数据砖头").startsWith(UTF8String.fromString("数据")));
+ Assert.assertFalse(UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("千")));
+ Assert.assertFalse(
+ UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("大千世界好")));
+ }
+
+ @Test
+ public void endsWith() {
+ Assert.assertTrue(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ello")));
+ Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ellov")));
+ Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("hhhello")));
+ Assert.assertTrue(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世界")));
+ Assert.assertFalse(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世")));
+ Assert.assertFalse(
+ UTF8String.fromString("数据砖头").endsWith(UTF8String.fromString("我的数据砖头")));
+ }
+
+ @Test
+ public void substring() {
+ Assert.assertEquals(
+ UTF8String.fromString("hello").substring(0, 0), UTF8String.fromString(""));
+ Assert.assertEquals(
+ UTF8String.fromString("hello").substring(1, 3), UTF8String.fromString("el"));
+ Assert.assertEquals(
+ UTF8String.fromString("数据砖头").substring(0, 1), UTF8String.fromString("数"));
+ Assert.assertEquals(
+ UTF8String.fromString("数据砖头").substring(1, 3), UTF8String.fromString("据砖"));
+ Assert.assertEquals(
+ UTF8String.fromString("数据砖头").substring(3, 5), UTF8String.fromString("头"));
+ }
+}
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 00d219f83670..644def7501dc 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -20,7 +20,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
@@ -39,6 +39,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-yarn-api</artifactId>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
index aaae6f9734a8..77af46c192cc 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
@@ -60,8 +60,11 @@ private[yarn] class AMDelegationTokenRenewer(
private val hadoopUtil = YarnSparkHadoopUtil.get
- private val daysToKeepFiles = sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5)
- private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5)
+ private val credentialsFile = sparkConf.get("spark.yarn.credentials.file")
+ private val daysToKeepFiles =
+ sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5)
+ private val numFilesToKeep =
+ sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5)
/**
* Schedule a login from the keytab and principal set using the --principal and --keytab
@@ -121,7 +124,7 @@ private[yarn] class AMDelegationTokenRenewer(
import scala.concurrent.duration._
try {
val remoteFs = FileSystem.get(hadoopConf)
- val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file"))
+ val credentialsPath = new Path(credentialsFile)
val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis
hadoopUtil.listFilesSorted(
remoteFs, credentialsPath.getParent,
@@ -160,7 +163,7 @@ private[yarn] class AMDelegationTokenRenewer(
val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab)
logInfo("Successfully logged into KDC.")
val tempCreds = keytabLoggedInUGI.getCredentials
- val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file"))
+ val credentialsPath = new Path(credentialsFile)
val dst = credentialsPath.getParent
keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] {
// Get a copy of the credentials
@@ -186,8 +189,7 @@ private[yarn] class AMDelegationTokenRenewer(
}
val nextSuffix = lastCredentialsFileSuffix + 1
val tokenPathStr =
- sparkConf.get("spark.yarn.credentials.file") +
- SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix
+ credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix
val tokenPath = new Path(tokenPathStr)
val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
logInfo("Writing out delegation tokens to " + tempTokenPath.toString)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 760e458972d9..83dafa4a125d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -32,7 +32,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.spark.rpc._
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.SparkException
-import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
@@ -46,6 +46,14 @@ private[spark] class ApplicationMaster(
client: YarnRMClient)
extends Logging {
+ // Load the properties file with the Spark configuration and set entries as system properties,
+ // so that user code run inside the AM also has access to them.
+ if (args.propertiesFile != null) {
+ Utils.getPropertiesFromFile(args.propertiesFile).foreach { case (k, v) =>
+ sys.props(k) = v
+ }
+ }
+
// TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
// optimal as more containers are available. Might need to handle this better.
@@ -67,6 +75,7 @@ private[spark] class ApplicationMaster(
@volatile private var reporterThread: Thread = _
@volatile private var allocator: YarnAllocator = _
+ private val allocatorLock = new Object()
// Fields used in client mode.
private var rpcEnv: RpcEnv = null
@@ -359,7 +368,9 @@ private[spark] class ApplicationMaster(
}
logDebug(s"Number of pending allocations is $numPendingAllocate. " +
s"Sleeping for $sleepInterval.")
- Thread.sleep(sleepInterval)
+ allocatorLock.synchronized {
+ allocatorLock.wait(sleepInterval)
+ }
} catch {
case e: InterruptedException =>
}
@@ -487,9 +498,11 @@ private[spark] class ApplicationMaster(
new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
}
+ var userArgs = args.userArgs
if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) {
- System.setProperty("spark.submit.pyFiles",
- PythonRunner.formatPaths(args.pyFiles).mkString(","))
+ // When running pyspark, the app is run using PythonRunner. The second argument is the list
+ // of files to add to PYTHONPATH, which Client.scala already handles, so it's empty.
+ userArgs = Seq(args.primaryPyFile, "") ++ userArgs
}
if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) {
// TODO(davies): add R dependencies here
@@ -500,9 +513,7 @@ private[spark] class ApplicationMaster(
val userThread = new Thread {
override def run() {
try {
- val mainArgs = new Array[String](args.userArgs.size)
- args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size)
- mainMethod.invoke(null, mainArgs)
+ mainMethod.invoke(null, userArgs.toArray)
finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
logDebug("Done running users class")
} catch {
@@ -546,8 +557,15 @@ private[spark] class ApplicationMaster(
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RequestExecutors(requestedTotal) =>
Option(allocator) match {
- case Some(a) => a.requestTotalExecutors(requestedTotal)
- case None => logWarning("Container allocator is not ready to request executors yet.")
+ case Some(a) =>
+ allocatorLock.synchronized {
+ if (a.requestTotalExecutors(requestedTotal)) {
+ allocatorLock.notifyAll()
+ }
+ }
+
+ case None =>
+ logWarning("Container allocator is not ready to request executors yet.")
}
context.reply(true)
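The allocatorLock changes above replace a plain sleep in the reporter loop with a timed wait that a new executor request can cut short. A generic sketch of that monitor hand-off follows; it is plain JVM synchronization, not Spark API, and the interval value is made up.

    val lock = new Object()
    val sleepIntervalMs = 200L                          // hypothetical reporter interval

    // Reporter thread: wakes up early if another thread calls notifyAll().
    lock.synchronized { lock.wait(sleepIntervalMs) }

    // RPC handler thread, once the requested executor total actually changes:
    lock.synchronized { lock.notifyAll() }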
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
index ae6dc1094d72..68e9f6b4db7f 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -26,11 +26,11 @@ class ApplicationMasterArguments(val args: Array[String]) {
var userClass: String = null
var primaryPyFile: String = null
var primaryRFile: String = null
- var pyFiles: String = null
- var userArgs: Seq[String] = Seq[String]()
+ var userArgs: Seq[String] = Nil
var executorMemory = 1024
var executorCores = 1
var numExecutors = DEFAULT_NUMBER_EXECUTORS
+ var propertiesFile: String = null
parseArgs(args.toList)
@@ -59,10 +59,6 @@ class ApplicationMasterArguments(val args: Array[String]) {
primaryRFile = value
args = tail
- case ("--py-files") :: value :: tail =>
- pyFiles = value
- args = tail
-
case ("--args" | "--arg") :: value :: tail =>
userArgsBuffer += value
args = tail
@@ -79,6 +75,10 @@ class ApplicationMasterArguments(val args: Array[String]) {
executorCores = value
args = tail
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ args = tail
+
case _ =>
printUsageAndExit(1, args)
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 234051eb7d3b..da1ec2a0fe2e 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -17,18 +17,21 @@
package org.apache.spark.deploy.yarn
-import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException}
+import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException,
+ OutputStreamWriter}
import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException}
import java.nio.ByteBuffer
import java.security.PrivilegedExceptionAction
-import java.util.UUID
+import java.util.{Properties, UUID}
import java.util.zip.{ZipEntry, ZipOutputStream}
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
import scala.reflect.runtime.universe
import scala.util.{Try, Success, Failure}
+import scala.util.control.NonFatal
+import com.google.common.base.Charsets.UTF_8
import com.google.common.base.Objects
import com.google.common.io.Files
@@ -121,24 +124,31 @@ private[spark] class Client(
} catch {
case e: Throwable =>
if (appId != null) {
- val appStagingDir = getAppStagingDir(appId)
- try {
- val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false)
- val stagingDirPath = new Path(appStagingDir)
- val fs = FileSystem.get(hadoopConf)
- if (!preserveFiles && fs.exists(stagingDirPath)) {
- logInfo("Deleting staging directory " + stagingDirPath)
- fs.delete(stagingDirPath, true)
- }
- } catch {
- case ioe: IOException =>
- logWarning("Failed to cleanup staging dir " + appStagingDir, ioe)
- }
+ cleanupStagingDir(appId)
}
throw e
}
}
+ /**
+ * Cleanup application staging directory.
+ */
+ private def cleanupStagingDir(appId: ApplicationId): Unit = {
+ val appStagingDir = getAppStagingDir(appId)
+ try {
+ val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false)
+ val stagingDirPath = new Path(appStagingDir)
+ val fs = FileSystem.get(hadoopConf)
+ if (!preserveFiles && fs.exists(stagingDirPath)) {
+ logInfo("Deleting staging directory " + stagingDirPath)
+ fs.delete(stagingDirPath, true)
+ }
+ } catch {
+ case ioe: IOException =>
+ logWarning("Failed to cleanup staging dir " + appStagingDir, ioe)
+ }
+ }
+
/**
* Set up the context for submitting our ApplicationMaster.
* This uses the YarnClientApplication not available in the Yarn alpha API.
@@ -240,7 +250,9 @@ private[spark] class Client(
* This is used for setting up a container launch context for our ApplicationMaster.
* Exposed for testing.
*/
- def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = {
+ def prepareLocalResources(
+ appStagingDir: String,
+ pySparkArchives: Seq[String]): HashMap[String, LocalResource] = {
logInfo("Preparing resources for our AM container")
// Upload Spark and the application JAR to the remote file system if necessary,
// and add them as local resources to the application master.
@@ -270,20 +282,6 @@ private[spark] class Client(
"for alternatives.")
}
- // If we passed in a keytab, make sure we copy the keytab to the staging directory on
- // HDFS, and setup the relevant environment vars, so the AM can login again.
- if (loginFromKeytab) {
- logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" +
- " via the YARN Secure Distributed Cache.")
- val localUri = new URI(args.keytab)
- val localPath = getQualifiedLocalPath(localUri, hadoopConf)
- val destinationPath = copyFileToRemote(dst, localPath, replication)
- val destFs = FileSystem.get(destinationPath.toUri(), hadoopConf)
- distCacheMgr.addResource(
- destFs, hadoopConf, destinationPath, localResources, LocalResourceType.FILE,
- sparkConf.get("spark.yarn.keytab"), statCache, appMasterOnly = true)
- }
-
def addDistributedUri(uri: URI): Boolean = {
val uriStr = uri.toString()
if (distributedUris.contains(uriStr)) {
@@ -295,6 +293,57 @@ private[spark] class Client(
}
}
+ /**
+ * Distribute a file to the cluster.
+ *
+ * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied
+ * to HDFS (if not already there) and added to the application's distributed cache.
+ *
+ * @param path URI of the file to distribute.
+ * @param resType Type of resource being distributed.
+ * @param destName Name of the file in the distributed cache.
+ * @param targetDir Subdirectory in which to place the file.
+ * @param appMasterOnly Whether to distribute only to the AM.
+ * @return A 2-tuple. First item is whether the file is a "local:" URI. Second item is the
+ * localized path for non-local paths, or the input `path` for local paths.
+ * The localized path will be null if the URI has already been added to the cache.
+ */
+ def distribute(
+ path: String,
+ resType: LocalResourceType = LocalResourceType.FILE,
+ destName: Option[String] = None,
+ targetDir: Option[String] = None,
+ appMasterOnly: Boolean = false): (Boolean, String) = {
+ val localURI = new URI(path.trim())
+ if (localURI.getScheme != LOCAL_SCHEME) {
+ if (addDistributedUri(localURI)) {
+ val localPath = getQualifiedLocalPath(localURI, hadoopConf)
+ val linkname = targetDir.map(_ + "/").getOrElse("") +
+ destName.orElse(Option(localURI.getFragment())).getOrElse(localPath.getName())
+ val destPath = copyFileToRemote(dst, localPath, replication)
+ distCacheMgr.addResource(
+ fs, hadoopConf, destPath, localResources, resType, linkname, statCache,
+ appMasterOnly = appMasterOnly)
+ (false, linkname)
+ } else {
+ (false, null)
+ }
+ } else {
+ (true, path.trim())
+ }
+ }
+
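A hedged, standalone illustration of the link-name rule the new distribute() applies (an explicit destName wins, then the URI fragment, then the plain file name, with targetDir as an optional prefix). The helper below is hypothetical and uses only java.net.URI.

import java.net.URI

object LinkNameSketch {
  def linkNameFor(path: String,
      destName: Option[String] = None,
      targetDir: Option[String] = None): String = {
    val uri = new URI(path.trim())
    // Fall back to the last path segment when no destName or fragment is given.
    val fileName = uri.getPath.split('/').last
    targetDir.map(_ + "/").getOrElse("") +
      destName.orElse(Option(uri.getFragment)).getOrElse(fileName)
  }

  def main(args: Array[String]): Unit = {
    println(linkNameFor("hdfs:///libs/dep.jar#renamed.jar"))                  // renamed.jar
    println(linkNameFor("/home/user/util.py", targetDir = Some("__pyfiles__"))) // __pyfiles__/util.py
  }
}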
+ // If we passed in a keytab, make sure we copy the keytab to the staging directory on
+ // HDFS, and setup the relevant environment vars, so the AM can login again.
+ if (loginFromKeytab) {
+ logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" +
+ " via the YARN Secure Distributed Cache.")
+ val (_, localizedPath) = distribute(args.keytab,
+ destName = Some(sparkConf.get("spark.yarn.keytab")),
+ appMasterOnly = true)
+ require(localizedPath != null, "Keytab file already distributed.")
+ }
+
/**
* Copy the given main resource to the distributed cache if the scheme is not "local".
* Otherwise, set the corresponding key in our SparkConf to handle it downstream.
@@ -307,33 +356,18 @@ private[spark] class Client(
(SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR),
(APP_JAR, args.userJar, CONF_SPARK_USER_JAR),
("log4j.properties", oldLog4jConf.orNull, null)
- ).foreach { case (destName, _localPath, confKey) =>
- val localPath: String = if (_localPath != null) _localPath.trim() else ""
- if (!localPath.isEmpty()) {
- val localURI = new URI(localPath)
- if (localURI.getScheme != LOCAL_SCHEME) {
- if (addDistributedUri(localURI)) {
- val src = getQualifiedLocalPath(localURI, hadoopConf)
- val destPath = copyFileToRemote(dst, src, replication)
- val destFs = FileSystem.get(destPath.toUri(), hadoopConf)
- distCacheMgr.addResource(destFs, hadoopConf, destPath,
- localResources, LocalResourceType.FILE, destName, statCache)
- }
- } else if (confKey != null) {
+ ).foreach { case (destName, path, confKey) =>
+ if (path != null && !path.trim().isEmpty()) {
+ val (isLocal, localizedPath) = distribute(path, destName = Some(destName))
+ if (isLocal && confKey != null) {
+ require(localizedPath != null, s"Path $path already distributed.")
// If the resource is intended for local use only, handle this downstream
// by setting the appropriate property
- sparkConf.set(confKey, localPath)
+ sparkConf.set(confKey, localizedPath)
}
}
}
- createConfArchive().foreach { file =>
- require(addDistributedUri(file.toURI()))
- val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication)
- distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE,
- LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true)
- }
-
/**
* Do the same for any additional resources passed in through ClientArguments.
* Each resource category is represented by a 3-tuple of:
@@ -349,21 +383,10 @@ private[spark] class Client(
).foreach { case (flist, resType, addToClasspath) =>
if (flist != null && !flist.isEmpty()) {
flist.split(',').foreach { file =>
- val localURI = new URI(file.trim())
- if (localURI.getScheme != LOCAL_SCHEME) {
- if (addDistributedUri(localURI)) {
- val localPath = new Path(localURI)
- val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
- val destPath = copyFileToRemote(dst, localPath, replication)
- distCacheMgr.addResource(
- fs, hadoopConf, destPath, localResources, resType, linkname, statCache)
- if (addToClasspath) {
- cachedSecondaryJarLinks += linkname
- }
- }
- } else if (addToClasspath) {
- // Resource is intended for local use only and should be added to the class path
- cachedSecondaryJarLinks += file.trim()
+ val (_, localizedPath) = distribute(file, resType = resType)
+ require(localizedPath != null)
+ if (addToClasspath) {
+ cachedSecondaryJarLinks += localizedPath
}
}
}
@@ -372,11 +395,31 @@ private[spark] class Client(
sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(","))
}
+ if (isClusterMode && args.primaryPyFile != null) {
+ distribute(args.primaryPyFile, appMasterOnly = true)
+ }
+
+ pySparkArchives.foreach { f => distribute(f) }
+
+ // The python files list needs to be treated specially. All files that are not an
+ // archive need to be placed in a subdirectory that will be added to PYTHONPATH.
+ args.pyFiles.foreach { f =>
+ val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None
+ distribute(f, targetDir = targetDir)
+ }
+
+ // Distribute an archive with Hadoop and Spark configuration for the AM.
+ val (_, confLocalizedPath) = distribute(createConfArchive().getAbsolutePath(),
+ resType = LocalResourceType.ARCHIVE,
+ destName = Some(LOCALIZED_CONF_DIR),
+ appMasterOnly = true)
+ require(confLocalizedPath != null)
+
localResources
}
/**
- * Create an archive with the Hadoop config files for distribution.
+ * Create an archive with the config files for distribution.
*
* These are only used by the AM, since executors will use the configuration object broadcast by
* the driver. The files are zipped and added to the job as an archive, so that YARN will explode
@@ -388,8 +431,11 @@ private[spark] class Client(
*
* Currently this makes a shallow copy of the conf directory. If there are cases where a
* Hadoop config directory contains subdirectories, this code will have to be fixed.
+ *
+ * The archive also contains some Spark configuration. Namely, it saves the contents of
+ * SparkConf in a file to be loaded by the AM process.
*/
- private def createConfArchive(): Option[File] = {
+ private def createConfArchive(): File = {
val hadoopConfFiles = new HashMap[String, File]()
Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey =>
sys.env.get(envKey).foreach { path =>
@@ -404,28 +450,32 @@ private[spark] class Client(
}
}
- if (!hadoopConfFiles.isEmpty) {
- val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip",
- new File(Utils.getLocalDir(sparkConf)))
+ val confArchive = File.createTempFile(LOCALIZED_CONF_DIR, ".zip",
+ new File(Utils.getLocalDir(sparkConf)))
+ val confStream = new ZipOutputStream(new FileOutputStream(confArchive))
- val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive))
- try {
- hadoopConfStream.setLevel(0)
- hadoopConfFiles.foreach { case (name, file) =>
- if (file.canRead()) {
- hadoopConfStream.putNextEntry(new ZipEntry(name))
- Files.copy(file, hadoopConfStream)
- hadoopConfStream.closeEntry()
- }
+ try {
+ confStream.setLevel(0)
+ hadoopConfFiles.foreach { case (name, file) =>
+ if (file.canRead()) {
+ confStream.putNextEntry(new ZipEntry(name))
+ Files.copy(file, confStream)
+ confStream.closeEntry()
}
- } finally {
- hadoopConfStream.close()
}
- Some(hadoopConfArchive)
- } else {
- None
+ // Save Spark configuration to a file in the archive.
+ val props = new Properties()
+ sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) }
+ confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE))
+ val writer = new OutputStreamWriter(confStream, UTF_8)
+ props.store(writer, "Spark configuration.")
+ writer.flush()
+ confStream.closeEntry()
+ } finally {
+ confStream.close()
}
+ confArchive
}
/**
@@ -453,7 +503,9 @@ private[spark] class Client(
/**
* Set up the environment for launching our ApplicationMaster container.
*/
- private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = {
+ private def setupLaunchEnv(
+ stagingDir: String,
+ pySparkArchives: Seq[String]): HashMap[String, String] = {
logInfo("Setting up the launch environment for our AM container")
val env = new HashMap[String, String]()
val extraCp = sparkConf.getOption("spark.driver.extraClassPath")
@@ -471,9 +523,6 @@ private[spark] class Client(
val renewalInterval = getTokenRenewalInterval(stagingDirPath)
sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString)
}
- // Set the environment variables to be passed on to the executors.
- distCacheMgr.setDistFilesEnv(env)
- distCacheMgr.setDistArchivesEnv(env)
// Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.*
val amEnvPrefix = "spark.yarn.appMasterEnv."
@@ -490,15 +539,32 @@ private[spark] class Client(
env("SPARK_YARN_USER_ENV") = userEnvs
}
- // if spark.submit.pyArchives is in sparkConf, append pyArchives to PYTHONPATH
- // that can be passed on to the ApplicationMaster and the executors.
- if (sparkConf.contains("spark.submit.pyArchives")) {
- var pythonPath = sparkConf.get("spark.submit.pyArchives")
- if (env.contains("PYTHONPATH")) {
- pythonPath = Seq(env.get("PYTHONPATH"), pythonPath).mkString(File.pathSeparator)
+ // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH
+ // of the container processes too. Add all non-.py files directly to PYTHONPATH.
+ //
+ // NOTE: the code currently does not handle .py files defined with a "local:" scheme.
+ val pythonPath = new ListBuffer[String]()
+ val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py"))
+ if (pyFiles.nonEmpty) {
+ pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ LOCALIZED_PYTHON_DIR)
+ }
+ (pySparkArchives ++ pyArchives).foreach { path =>
+ val uri = new URI(path)
+ if (uri.getScheme != LOCAL_SCHEME) {
+ pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ new Path(path).getName())
+ } else {
+ pythonPath += uri.getPath()
}
- env("PYTHONPATH") = pythonPath
- sparkConf.setExecutorEnv("PYTHONPATH", pythonPath)
+ }
+
+ // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors.
+ if (pythonPath.nonEmpty) {
+ val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath)
+ .mkString(YarnSparkHadoopUtil.getClassPathSeparator)
+ env("PYTHONPATH") = pythonPathStr
+ sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr)
}
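A hedged, self-contained sketch of the PYTHONPATH rules applied above; the pwd and localizedPythonDir parameters are illustrative placeholders, not the Spark constants. Plain .py files share one localized subdirectory, archives are referenced by their localized name, and "local:" URIs keep their literal path on the node.

import java.io.File
import java.net.URI

object PythonPathSketch {
  def build(pyFiles: Seq[String], pwd: String, localizedPythonDir: String): Seq[String] = {
    val (plainPy, archives) = pyFiles.partition(_.endsWith(".py"))
    val pyDirEntry =
      if (plainPy.nonEmpty) Seq(s"$pwd${File.separator}$localizedPythonDir") else Nil
    val archiveEntries = archives.map { path =>
      val uri = new URI(path)
      if (uri.getScheme == "local") uri.getPath
      else s"$pwd${File.separator}${new File(uri.getPath).getName}"
    }
    pyDirEntry ++ archiveEntries
  }

  def main(args: Array[String]): Unit = {
    // Expected output: List({{PWD}}/__pyfiles__, {{PWD}}/deps.zip, /opt/libs/extra.zip)
    println(build(Seq("util.py", "deps.zip", "local:/opt/libs/extra.zip"), "{{PWD}}", "__pyfiles__"))
  }
}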
// In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to
@@ -548,8 +614,19 @@ private[spark] class Client(
logInfo("Setting up container launch context for our AM")
val appId = newAppResponse.getApplicationId
val appStagingDir = getAppStagingDir(appId)
- val localResources = prepareLocalResources(appStagingDir)
- val launchEnv = setupLaunchEnv(appStagingDir)
+ val pySparkArchives =
+ if (sys.props.getOrElse("spark.yarn.isPython", "false").toBoolean) {
+ findPySparkArchives()
+ } else {
+ Nil
+ }
+ val launchEnv = setupLaunchEnv(appStagingDir, pySparkArchives)
+ val localResources = prepareLocalResources(appStagingDir, pySparkArchives)
+
+ // Set the environment variables to be passed on to the executors.
+ distCacheMgr.setDistFilesEnv(launchEnv)
+ distCacheMgr.setDistArchivesEnv(launchEnv)
+
val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
amContainer.setLocalResources(localResources)
amContainer.setEnvironment(launchEnv)
@@ -589,13 +666,6 @@ private[spark] class Client(
javaOpts += "-XX:CMSIncrementalDutyCycle=10"
}
- // Forward the Spark configuration to the application master / executors.
- // TODO: it might be nicer to pass these as an internal environment variable rather than
- // as Java options, due to complications with string parsing of nested quotes.
- for ((k, v) <- sparkConf.getAll) {
- javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v")
- }
-
// Include driver-specific java options if we are launching a driver
if (isClusterMode) {
val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions")
@@ -648,14 +718,8 @@ private[spark] class Client(
Nil
}
val primaryPyFile =
- if (args.primaryPyFile != null) {
- Seq("--primary-py-file", args.primaryPyFile)
- } else {
- Nil
- }
- val pyFiles =
- if (args.pyFiles != null) {
- Seq("--py-files", args.pyFiles)
+ if (isClusterMode && args.primaryPyFile != null) {
+ Seq("--primary-py-file", new Path(args.primaryPyFile).getName())
} else {
Nil
}
@@ -671,9 +735,6 @@ private[spark] class Client(
} else {
Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName
}
- if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) {
- args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs
- }
if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) {
args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs
}
@@ -681,11 +742,13 @@ private[spark] class Client(
Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg))
}
val amArgs =
- Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++
+ Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++
userArgs ++ Seq(
"--executor-memory", args.executorMemory.toString + "m",
"--executor-cores", args.executorCores.toString,
- "--num-executors ", args.numExecutors.toString)
+ "--num-executors ", args.numExecutors.toString,
+ "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ LOCALIZED_CONF_DIR, SPARK_CONF_FILE))
// Command for the ApplicationMaster
val commands = prefixEnv ++ Seq(
@@ -764,6 +827,9 @@ private[spark] class Client(
case e: ApplicationNotFoundException =>
logError(s"Application $appId not found.")
return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED)
+ case NonFatal(e) =>
+ logError(s"Failed to contact YARN for application $appId.", e)
+ return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED)
}
val state = report.getYarnApplicationState
@@ -782,6 +848,7 @@ private[spark] class Client(
if (state == YarnApplicationState.FINISHED ||
state == YarnApplicationState.FAILED ||
state == YarnApplicationState.KILLED) {
+ cleanupStagingDir(appId)
return (state, report.getFinalApplicationStatus)
}
@@ -849,6 +916,22 @@ private[spark] class Client(
}
}
}
+
+ private def findPySparkArchives(): Seq[String] = {
+ sys.env.get("PYSPARK_ARCHIVES_PATH")
+ .map(_.split(",").toSeq)
+ .getOrElse {
+ val pyLibPath = Seq(sys.env("SPARK_HOME"), "python", "lib").mkString(File.separator)
+ val pyArchivesFile = new File(pyLibPath, "pyspark.zip")
+ require(pyArchivesFile.exists(),
+ "pyspark.zip not found; cannot run pyspark application in YARN mode.")
+ val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip")
+ require(py4jFile.exists(),
+ "py4j-0.8.2.1-src.zip not found; cannot run pyspark application in YARN mode.")
+ Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath())
+ }
+ }
+
}
object Client extends Logging {
@@ -899,8 +982,14 @@ object Client extends Logging {
// Distribution-defined classpath to add to processes
val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH"
- // Subdirectory where the user's hadoop config files will be placed.
- val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__"
+ // Subdirectory where the user's Spark and Hadoop config files will be placed.
+ val LOCALIZED_CONF_DIR = "__spark_conf__"
+
+ // Name of the file in the conf archive containing Spark configuration.
+ val SPARK_CONF_FILE = "__spark_conf__.properties"
+
+ // Subdirectory where the user's python files (not archives) will be placed.
+ val LOCALIZED_PYTHON_DIR = "__pyfiles__"
/**
* Find the user-defined Spark jar if configured, or return the jar containing this
@@ -1025,7 +1114,7 @@ object Client extends Logging {
if (isAM) {
addClasspathEntry(
YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR +
- LOCALIZED_HADOOP_CONF_DIR, env)
+ LOCALIZED_CONF_DIR, env)
}
if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) {
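Since the new --properties-file argument points the AM at __spark_conf__/__spark_conf__.properties in the container's working directory, here is a hedged sketch of how such a file can be read back into key/value pairs; the loader below is illustrative, not the ApplicationMaster's code.

import java.io.{FileInputStream, InputStreamReader}
import java.nio.charset.StandardCharsets.UTF_8
import java.util.Properties
import scala.collection.JavaConverters._

object PropertiesFileSketch {
  def load(path: String): Map[String, String] = {
    val props = new Properties()
    val reader = new InputStreamReader(new FileInputStream(path), UTF_8)
    try {
      props.load(reader)
    } finally {
      reader.close()
    }
    props.stringPropertyNames().asScala.map(k => k -> props.getProperty(k)).toMap
  }

  def main(args: Array[String]): Unit =
    load("__spark_conf__/__spark_conf__.properties").foreach { case (k, v) => println(s"$k=$v") }
}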
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 5653c9f14dc6..35e990602a6c 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -30,7 +30,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
var archives: String = null
var userJar: String = null
var userClass: String = null
- var pyFiles: String = null
+ var pyFiles: Seq[String] = Nil
var primaryPyFile: String = null
var primaryRFile: String = null
var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]()
@@ -98,6 +98,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
numExecutors = initialNumExecutors
}
+ principal = Option(principal)
+ .orElse(sparkConf.getOption("spark.yarn.principal"))
+ .orNull
+ keytab = Option(keytab)
+ .orElse(sparkConf.getOption("spark.yarn.keytab"))
+ .orNull
}
/**
@@ -222,7 +228,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
args = tail
case ("--py-files") :: value :: tail =>
- pyFiles = value
+ pyFiles = value.split(",")
args = tail
case ("--files") :: value :: tail =>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
index 4ca6c903fcf1..3d3a966960e9 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -43,22 +43,22 @@ private[spark] class ClientDistributedCacheManager() extends Logging {
* Add a resource to the list of distributed cache resources. This list can
* be sent to the ApplicationMaster and possibly the executors so that it can
* be downloaded into the Hadoop distributed cache for use by this application.
- * Adds the LocalResource to the localResources HashMap passed in and saves
+ * Adds the LocalResource to the localResources HashMap passed in and saves
* the stats of the resources so they can be sent to the executors and verified.
*
* @param fs FileSystem
* @param conf Configuration
* @param destPath path to the resource
* @param localResources localResource hashMap to insert the resource into
- * @param resourceType LocalResourceType
+ * @param resourceType LocalResourceType
* @param link link presented in the distributed cache to the destination
- * @param statCache cache to store the file/directory stats
+ * @param statCache cache to store the file/directory stats
* @param appMasterOnly Whether to only add the resource to the app master
*/
def addResource(
fs: FileSystem,
conf: Configuration,
- destPath: Path,
+ destPath: Path,
localResources: HashMap[String, LocalResource],
resourceType: LocalResourceType,
link: String,
@@ -74,15 +74,15 @@ private[spark] class ClientDistributedCacheManager() extends Logging {
amJarRsrc.setSize(destStatus.getLen())
if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name")
localResources(link) = amJarRsrc
-
+
if (!appMasterOnly) {
val uri = destPath.toUri()
val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link)
if (resourceType == LocalResourceType.FILE) {
- distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(),
+ distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(),
destStatus.getModificationTime().toString(), visibility.name())
} else {
- distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(),
+ distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(),
destStatus.getModificationTime().toString(), visibility.name())
}
}
@@ -96,11 +96,11 @@ private[spark] class ClientDistributedCacheManager() extends Logging {
val (sizes, timeStamps, visibilities) = tupleValues.unzip3
if (keys.size > 0) {
env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n }
- env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") =
+ env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") =
timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n }
- env("SPARK_YARN_CACHE_FILES_FILE_SIZES") =
+ env("SPARK_YARN_CACHE_FILES_FILE_SIZES") =
sizes.reduceLeft[String] { (acc, n) => acc + "," + n }
- env("SPARK_YARN_CACHE_FILES_VISIBILITIES") =
+ env("SPARK_YARN_CACHE_FILES_VISIBILITIES") =
visibilities.reduceLeft[String] { (acc, n) => acc + "," + n }
}
}
@@ -113,11 +113,11 @@ private[spark] class ClientDistributedCacheManager() extends Logging {
val (sizes, timeStamps, visibilities) = tupleValues.unzip3
if (keys.size > 0) {
env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n }
- env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") =
+ env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") =
timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n }
env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") =
sizes.reduceLeft[String] { (acc, n) => acc + "," + n }
- env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") =
+ env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") =
visibilities.reduceLeft[String] { (acc, n) => acc + "," + n }
}
}
@@ -197,7 +197,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging {
def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = {
val stat = statCache.get(uri) match {
case Some(existstat) => existstat
- case None =>
+ case None =>
val newStat = fs.getFileStatus(new Path(uri))
statCache.put(uri, newStat)
newStat
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 9d04d241dae9..b0937083bc53 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -303,8 +303,8 @@ class ExecutorRunnable(
val address = container.getNodeHttpAddress
val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user"
- env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=0"
- env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=0"
+ env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=-4096"
+ env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096"
}
System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v }
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 21193e7c625e..940873fbd046 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -146,11 +146,16 @@ private[yarn] class YarnAllocator(
* Request as many executors from the ResourceManager as needed to reach the desired total. If
* the requested total is smaller than the current number of running executors, no executors will
* be killed.
+ *
+ * @return Whether the new requested total is different from the old value.
*/
- def requestTotalExecutors(requestedTotal: Int): Unit = synchronized {
+ def requestTotalExecutors(requestedTotal: Int): Boolean = synchronized {
if (requestedTotal != targetNumExecutors) {
logInfo(s"Driver requested a total number of $requestedTotal executor(s).")
targetNumExecutors = requestedTotal
+ true
+ } else {
+ false
}
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 5e6531895c7b..68d01c17ef72 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -144,9 +144,9 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
}
object YarnSparkHadoopUtil {
- // Additional memory overhead
+ // Additional memory overhead
// 10% was arrived at experimentally. In the interest of minimizing memory waste while covering
- // the common cases. Memory overhead tends to grow with container size.
+ // the common cases. Memory overhead tends to grow with container size.
val MEMORY_OVERHEAD_FACTOR = 0.10
val MEMORY_OVERHEAD_MIN = 384
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 99c05329b4d7..1c8d7ec57635 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -76,7 +76,8 @@ private[spark] class YarnClientSchedulerBackend(
("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"),
("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"),
("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"),
- ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue")
+ ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"),
+ ("--py-files", null, "spark.submit.pyFiles")
)
// Warn against the following deprecated environment variables: env var -> suggestion
val deprecatedEnvVars = Map(
@@ -86,7 +87,7 @@ private[spark] class YarnClientSchedulerBackend(
optionTuples.foreach { case (optionName, envVar, sparkProp) =>
if (sc.getConf.contains(sparkProp)) {
extraArgs += (optionName, sc.getConf.get(sparkProp))
- } else if (System.getenv(envVar) != null) {
+ } else if (envVar != null && System.getenv(envVar) != null) {
extraArgs += (optionName, System.getenv(envVar))
if (deprecatedEnvVars.contains(envVar)) {
logWarning(s"NOTE: $envVar is deprecated. Use ${deprecatedEnvVars(envVar)} instead.")
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
index 80b57d1355a3..804dfecde786 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn
import java.net.URI
-import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar
import org.mockito.Mockito.when
@@ -36,16 +35,18 @@ import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
import scala.collection.mutable.HashMap
import scala.collection.mutable.Map
+import org.apache.spark.SparkFunSuite
-class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
+
+class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar {
class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
- override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
LocalResourceVisibility = {
LocalResourceVisibility.PRIVATE
}
}
-
+
test("test getFileStatus empty") {
val distMgr = new ClientDistributedCacheManager()
val fs = mock[FileSystem]
@@ -60,7 +61,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
val distMgr = new ClientDistributedCacheManager()
val fs = mock[FileSystem]
val uri = new URI("/tmp/testing")
- val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
+ val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
null, new Path("/tmp/testing"))
when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus)
@@ -77,7 +78,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
- distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
statCache, false)
val resource = localResources("link")
assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
@@ -100,11 +101,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
// add another one and verify both there and order correct
- val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
null, new Path("/tmp/testing2"))
val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2")
when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus)
- distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
+ distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
statCache, false)
val resource2 = localResources("link2")
assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE)
@@ -116,7 +117,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
val env2 = new HashMap[String, String]()
distMgr.setDistFilesEnv(env2)
val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
- val files = env2("SPARK_YARN_CACHE_FILES").split(',')
+ val files = env2("SPARK_YARN_CACHE_FILES").split(',')
val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',')
assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link")
@@ -140,7 +141,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
intercept[Exception] {
- distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
statCache, false)
}
assert(localResources.get("link") === None)
@@ -154,11 +155,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
val localResources = HashMap[String, LocalResource]()
val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
- val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
null, new Path("/tmp/testing"))
when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
- distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
statCache, true)
val resource = localResources("link")
assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
@@ -188,11 +189,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
val localResources = HashMap[String, LocalResource]()
val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
- val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
null, new Path("/tmp/testing"))
when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
- distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
statCache, false)
val resource = localResources("link")
assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 6da3e82acdb1..4ec976aa3138 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -33,12 +33,12 @@ import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.mockito.Matchers._
import org.mockito.Mockito._
-import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfterAll, Matchers}
-import org.apache.spark.{SparkException, SparkConf}
+import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
import org.apache.spark.util.Utils
-class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll {
+class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll {
override def beforeAll(): Unit = {
System.setProperty("SPARK_YARN_MODE", "true")
@@ -113,7 +113,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll {
Environment.PWD.$()
}
cp should contain(pwdVar)
- cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}")
+ cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_CONF_DIR}")
cp should not contain (Client.SPARK_JAR)
cp should not contain (Client.APP_JAR)
}
@@ -129,7 +129,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll {
val tempDir = Utils.createTempDir()
try {
- client.prepareLocalResources(tempDir.getAbsolutePath())
+ client.prepareLocalResources(tempDir.getAbsolutePath(), Nil)
sparkConf.getOption(Client.CONF_SPARK_USER_JAR) should be (Some(USER))
// The non-local path should be propagated by name only, since it will end up in the app's
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index b343cbb0c756..7509000771d9 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -26,13 +26,13 @@ import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
-import org.apache.spark.SecurityManager
+import org.apache.spark.{SecurityManager, SparkFunSuite}
import org.apache.spark.SparkConf
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
import org.apache.spark.deploy.yarn.YarnAllocator._
import org.apache.spark.scheduler.SplitInfo
-import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfterEach, Matchers}
class MockResolver extends DNSToSwitchMapping {
@@ -46,7 +46,7 @@ class MockResolver extends DNSToSwitchMapping {
def reloadCachedMappings(names: JList[String]) {}
}
-class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach {
+class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach {
val conf = new Configuration()
conf.setClass(
CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index dcaeb2e43ff4..a0f25ba45006 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -18,21 +18,21 @@
package org.apache.spark.deploy.yarn
import java.io.{File, FileOutputStream, OutputStreamWriter}
+import java.net.URL
import java.util.Properties
import java.util.concurrent.TimeUnit
import scala.collection.JavaConversions._
import scala.collection.mutable
-import scala.io.Source
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.ByteStreams
import com.google.common.io.Files
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.server.MiniYARNCluster
-import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfterAll, Matchers}
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils}
+import org.apache.spark._
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
SparkListenerExecutorAdded}
@@ -43,7 +43,7 @@ import org.apache.spark.util.Utils
* applications, and require the Spark assembly to be built before they can be successfully
* run.
*/
-class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging {
+class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging {
// log4j configuration for the YARN containers, so that their output is collected
// by YARN instead of trying to overwrite unit-tests.log.
@@ -56,6 +56,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
""".stripMargin
private val TEST_PYFILE = """
+ |import mod1, mod2
|import sys
|from operator import add
|
@@ -67,7 +68,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
| sc = SparkContext(conf=SparkConf())
| status = open(sys.argv[1],'w')
| result = "failure"
- | rdd = sc.parallelize(range(10))
+ | rdd = sc.parallelize(range(10)).map(lambda x: x * mod1.func() * mod2.func())
| cnt = rdd.count()
| if cnt == 10:
| result = "success"
@@ -76,6 +77,11 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
| sc.stop()
""".stripMargin
+ private val TEST_PYMODULE = """
+ |def func():
+ | return 42
+ """.stripMargin
+
private var yarnCluster: MiniYARNCluster = _
private var tempDir: File = _
private var fakeSparkJar: File = _
@@ -124,7 +130,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
- hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR)
+ hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR)
assert(hadoopConfDir.mkdir())
File.createTempFile("token", ".txt", hadoopConfDir)
}
@@ -151,26 +157,12 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
}
}
- // Enable this once fix SPARK-6700
- test("run Python application in yarn-cluster mode") {
- val primaryPyFile = new File(tempDir, "test.py")
- Files.write(TEST_PYFILE, primaryPyFile, UTF_8)
- val pyFile = new File(tempDir, "test2.py")
- Files.write(TEST_PYFILE, pyFile, UTF_8)
- var result = File.createTempFile("result", null, tempDir)
+ test("run Python application in yarn-client mode") {
+ testPySpark(true)
+ }
- // The sbt assembly does not include pyspark / py4j python dependencies, so we need to
- // propagate SPARK_HOME so that those are added to PYTHONPATH. See PythonUtils.scala.
- val sparkHome = sys.props("spark.test.home")
- val extraConf = Map(
- "spark.executorEnv.SPARK_HOME" -> sparkHome,
- "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome)
-
- runSpark(false, primaryPyFile.getAbsolutePath(),
- sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()),
- appArgs = Seq(result.getAbsolutePath()),
- extraConf = extraConf)
- checkResult(result)
+ test("run Python application in yarn-cluster mode") {
+ testPySpark(false)
}
test("user class path first in client mode") {
@@ -188,6 +180,33 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
checkResult(result)
}
+ private def testPySpark(clientMode: Boolean): Unit = {
+ val primaryPyFile = new File(tempDir, "test.py")
+ Files.write(TEST_PYFILE, primaryPyFile, UTF_8)
+
+ val moduleDir =
+ if (clientMode) {
+ // In client-mode, .py files added with --py-files are not visible in the driver.
+ // This is something that the launcher library would have to handle.
+ tempDir
+ } else {
+ val subdir = new File(tempDir, "pyModules")
+ subdir.mkdir()
+ subdir
+ }
+ val pyModule = new File(moduleDir, "mod1.py")
+ Files.write(TEST_PYMODULE, pyModule, UTF_8)
+
+ val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir)
+ val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",")
+ val result = File.createTempFile("result", null, tempDir)
+
+ runSpark(clientMode, primaryPyFile.getAbsolutePath(),
+ sparkArgs = Seq("--py-files", pyFiles),
+ appArgs = Seq(result.getAbsolutePath()))
+ checkResult(result)
+ }
+
private def testUseClassPathFirst(clientMode: Boolean): Unit = {
// Create a jar file that contains a different version of "test.resource".
val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
@@ -326,7 +345,7 @@ private object YarnClusterDriver extends Logging with Matchers {
var result = "failure"
try {
val data = sc.parallelize(1 to 4, 4).collect().toSet
- assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
data should be (Set(1, 2, 3, 4))
result = "success"
} finally {
@@ -344,18 +363,20 @@ private object YarnClusterDriver extends Logging with Matchers {
assert(info.logUrlMap.nonEmpty)
}
- // If we are running in yarn-cluster mode, verify that driver logs are downloadable.
+ // If we are running in yarn-cluster mode, verify that the driver log links are present and
+ // in the expected format.
if (conf.get("spark.master") == "yarn-cluster") {
assert(listener.driverLogs.nonEmpty)
val driverLogs = listener.driverLogs.get
assert(driverLogs.size === 2)
assert(driverLogs.containsKey("stderr"))
assert(driverLogs.containsKey("stdout"))
- val stderr = driverLogs("stderr") // YARN puts everything in stderr.
- val lines = Source.fromURL(stderr).getLines()
- // Look for a line that contains YarnClusterSchedulerBackend, since that is guaranteed in
- // cluster mode.
- assert(lines.exists(_.contains("YarnClusterSchedulerBackend")))
+ val urlStr = driverLogs("stderr")
+ // Ensure that this is a valid URL; otherwise the URL constructor throws an exception.
+ new URL(urlStr)
+ val containerId = YarnSparkHadoopUtil.get.getContainerId
+ val user = Utils.getCurrentUserName()
+ assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=0"))
}
}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
index e10b985c3c23..49bee0866dd4 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -25,15 +25,15 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.yarn.api.ApplicationConstants
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.scalatest.{FunSuite, Matchers}
+import org.scalatest.Matchers
import org.apache.hadoop.yarn.api.records.ApplicationAccessType
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite}
import org.apache.spark.util.Utils
-class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging {
+class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging {
val hasBash =
try {