diff --git a/.gitignore b/.gitignore index 1d91b43c23fa..903297db9690 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ R-unit-tests.log R/unit-tests.out R/cran-check.out R/pkg/vignettes/sparkr-vignettes.html +R/pkg/tests/fulltests/Rplots.pdf build/*.jar build/apache-maven* build/scala* @@ -46,6 +47,8 @@ dev/pr-deps/ dist/ docs/_site docs/api +sql/docs +sql/site lib_managed/ lint-r-report.log log/ diff --git a/LICENSE b/LICENSE index c21032a1fd27..39fe0dc46238 100644 --- a/LICENSE +++ b/LICENSE @@ -249,11 +249,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/) (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/R/README.md b/R/README.md index 4c40c5963db7..1152b1e8e5f9 100644 --- a/R/README.md +++ b/R/README.md @@ -66,11 +66,7 @@ To run one of them, use `./bin/spark-submit `. For example: ```bash ./bin/spark-submit examples/src/main/r/dataframe.R ``` -You can also run the unit tests for SparkR by running. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first: -```bash -R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' -./R/run-tests.sh -``` +You can run R unit tests by following the instructions under [Running R Tests](http://spark.apache.org/docs/latest/building-spark.html#running-r-tests). 
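Reviewer note: for readers who land on this README chunk without the linked page, the inline steps it replaces amounted to installing the test dependency and invoking the bundled runner. A minimal sketch, assuming a reachable CRAN mirror and the Spark checkout root as the working directory; the "Running R Tests" page linked above is now the authoritative list of required packages:

```r
# One-time install of the main test dependency (run inside R).
# The linked docs list a few more packages needed for the full suite.
install.packages("testthat", repos = "http://cran.us.r-project.org")
```

After that, `./R/run-tests.sh` runs the suite on Linux/macOS, and on Windows the `spark-submit2.cmd ... R\pkg\tests\run-all.R` invocation shown in the WINDOWS.md chunk below does the same.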
### Running on YARN diff --git a/R/WINDOWS.md b/R/WINDOWS.md index 9ca7e58e20cd..124bc631be9c 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -34,10 +34,9 @@ To run the SparkR unit tests on Windows, the following steps are required —ass 4. Set the environment variable `HADOOP_HOME` to the full path to the newly created `hadoop` directory. -5. Run unit tests for SparkR by running the command below. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first: +5. Run unit tests for SparkR by running the command below. You need to install the needed packages following the instructions under [Running R Tests](http://spark.apache.org/docs/latest/building-spark.html#running-r-tests) first: ``` - R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R ``` diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore index f12f8c275a98..18b2db69db8f 100644 --- a/R/pkg/.Rbuildignore +++ b/R/pkg/.Rbuildignore @@ -6,3 +6,4 @@ ^README\.Rmd$ ^src-native$ ^html$ +^tests/fulltests/* diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 879c1f80f2c5..d1c846c04827 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,8 +1,8 @@ Package: SparkR Type: Package -Version: 2.2.0 +Version: 2.3.0 Title: R Frontend for Apache Spark -Description: The SparkR package provides an R Frontend for Apache Spark. +Description: Provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index e8de34d9371a..3fc756b9ef40 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -63,6 +63,7 @@ exportMethods("glm", "spark.als", "spark.kstest", "spark.logit", + "spark.decisionTree", "spark.randomForest", "spark.gbt", "spark.bisectingKmeans", @@ -74,7 +75,8 @@ exportMethods("glm", # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", - "cancelJobGroup") + "cancelJobGroup", + "setJobDescription") # Export Utility methods export("setLogLevel") @@ -84,6 +86,7 @@ exportClasses("SparkDataFrame") exportMethods("arrange", "as.data.frame", "attach", + "broadcast", "cache", "checkpoint", "coalesce", @@ -123,6 +126,7 @@ exportMethods("arrange", "group_by", "groupBy", "head", + "hint", "insertInto", "intersect", "isLocal", @@ -165,6 +169,7 @@ exportMethods("arrange", "transform", "union", "unionAll", + "unionByName", "unique", "unpersist", "where", @@ -249,12 +254,15 @@ exportMethods("%<=>%", "getField", "getItem", "greatest", + "grouping_bit", + "grouping_id", "hex", "histogram", "hour", "hypot", "ifelse", "initcap", + "input_file_name", "instr", "isNaN", "isNotNull", @@ -279,6 +287,8 @@ exportMethods("%<=>%", "lower", "lpad", "ltrim", + "map_keys", + "map_values", "max", "md5", "mean", @@ -351,6 +361,7 @@ exportMethods("%<=>%", "to_utc_timestamp", "translate", "trim", + "trunc", "unbase64", "unhex", "unix_timestamp", @@ -409,6 +420,8 @@ export("as.DataFrame", "print.summary.GeneralizedLinearRegressionModel", "read.ml", "print.summary.KSTest", + "print.summary.DecisionTreeRegressionModel", + "print.summary.DecisionTreeClassificationModel", "print.summary.RandomForestRegressionModel", "print.summary.RandomForestClassificationModel", "print.summary.GBTRegressionModel", @@ -419,6 +432,7 @@ export("structField", "structField.character", "print.structField", "structType", + 
"structType.character", "structType.jobj", "structType.structField", "print.structType") @@ -447,11 +461,14 @@ S3method(print, structField) S3method(print, structType) S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) +S3method(print, summary.DecisionTreeRegressionModel) +S3method(print, summary.DecisionTreeClassificationModel) S3method(print, summary.RandomForestRegressionModel) S3method(print, summary.RandomForestClassificationModel) S3method(print, summary.GBTRegressionModel) S3method(print, summary.GBTClassificationModel) S3method(structField, character) S3method(structField, jobj) +S3method(structType, character) S3method(structType, jobj) S3method(structType, structField) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 7e57ba6287bb..0728141fa483 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -549,7 +549,7 @@ setMethod("registerTempTable", #' sparkR.session() #' df <- read.df(path, "parquet") #' df2 <- read.df(path2, "parquet") -#' createOrReplaceTempView(df, "table1") +#' saveAsTable(df, "table1") #' insertInto(df2, "table1", overwrite = TRUE) #'} #' @note insertInto since 1.4.0 @@ -593,7 +593,7 @@ setMethod("cache", #' #' Persist this SparkDataFrame with the specified storage level. For details of the #' supported storage levels, refer to -#' \url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. +#' \url{http://spark.apache.org/docs/latest/rdd-programming-guide.html#rdd-persistence}. #' #' @param x the SparkDataFrame to persist. #' @param newLevel storage level chosen for the persistance. See available options in @@ -986,10 +986,10 @@ setMethod("unique", #' @param x A SparkDataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction -#' @param seed Randomness seed value +#' @param seed Randomness seed value. Default is a random seed. 
#' #' @family SparkDataFrame functions -#' @aliases sample,SparkDataFrame,logical,numeric-method +#' @aliases sample,SparkDataFrame-method #' @rdname sample #' @name sample #' @export @@ -998,33 +998,47 @@ setMethod("unique", #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) +#' collect(sample(df, fraction = 0.5)) #' collect(sample(df, FALSE, 0.5)) -#' collect(sample(df, TRUE, 0.5)) +#' collect(sample(df, TRUE, 0.5, seed = 3)) #'} #' @note sample since 1.4.0 setMethod("sample", - signature(x = "SparkDataFrame", withReplacement = "logical", - fraction = "numeric"), - function(x, withReplacement, fraction, seed) { - if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) + signature(x = "SparkDataFrame"), + function(x, withReplacement = FALSE, fraction, seed) { + if (!is.numeric(fraction)) { + stop(paste("fraction must be numeric; however, got", class(fraction))) + } + if (!is.logical(withReplacement)) { + stop(paste("withReplacement must be logical; however, got", class(withReplacement))) + } + if (!missing(seed)) { + if (is.null(seed)) { + stop("seed must not be NULL or NA; however, got NULL") + } + if (is.na(seed)) { + stop("seed must not be NULL or NA; however, got NA") + } + # TODO : Figure out how to send integer as java.lang.Long to JVM so # we can send seed as an argument through callJMethod - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed)) + sdf <- handledCallJMethod(x@sdf, "sample", as.logical(withReplacement), + as.numeric(fraction), as.integer(seed)) } else { - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + sdf <- handledCallJMethod(x@sdf, "sample", + as.logical(withReplacement), as.numeric(fraction)) } dataFrame(sdf) }) #' @rdname sample -#' @aliases sample_frac,SparkDataFrame,logical,numeric-method +#' @aliases sample_frac,SparkDataFrame-method #' @name sample_frac #' @note sample_frac since 1.4.0 setMethod("sample_frac", - signature(x = "SparkDataFrame", withReplacement = "logical", - fraction = "numeric"), - function(x, withReplacement, fraction, seed) { + signature(x = "SparkDataFrame"), + function(x, withReplacement = FALSE, fraction, seed) { sample(x, withReplacement, fraction, seed) }) @@ -1125,7 +1139,8 @@ setMethod("dim", #' path <- "path/to/file.json" #' df <- read.json(path) #' collected <- collect(df) -#' firstName <- collected[[1]]$name +#' class(collected) +#' firstName <- names(collected)[1] #' } #' @note collect since 1.4.0 setMethod("collect", @@ -1390,6 +1405,10 @@ setMethod("summarize", }) dapplyInternal <- function(x, func, schema) { + if (is.character(schema)) { + schema <- structType(schema) + } + packageNamesArr <- serialize(.sparkREnv[[".packages"]], connection = NULL) @@ -1407,6 +1426,8 @@ dapplyInternal <- function(x, func, schema) { dataFrame(sdf) } +setClassUnion("characterOrstructType", c("character", "structType")) + #' dapply #' #' Apply a function to each partition of a SparkDataFrame. @@ -1417,10 +1438,11 @@ dapplyInternal <- function(x, func, schema) { #' to each partition will be passed. #' The output of func should be a R data.frame. #' @param schema The schema of the resulting SparkDataFrame after the function is applied. -#' It must match the output of func. +#' It must match the output of func. Since Spark 2.3, the DDL-formatted string +#' is also supported for the schema. 
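Reviewer note: with the relaxed `sample` signature above, `withReplacement` now defaults to `FALSE` and `seed` stays optional, so fraction-only calls dispatch correctly and bad argument types fail with a descriptive error. A minimal usage sketch, assuming an active SparkR session; the `faithful` dataset is my choice, not from the patch:

```r
library(SparkR)
sparkR.session()

df <- createDataFrame(faithful)

# fraction-only call: sampling without replacement, random seed
s1 <- sample(df, fraction = 0.5)

# explicit replacement flag and a fixed seed for reproducibility
s2 <- sample(df, withReplacement = TRUE, fraction = 0.5, seed = 3)

count(s1)
count(s2)

# a non-numeric fraction or non-logical withReplacement now raises a
# clear error instead of a missing-method failure
```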
#' @family SparkDataFrame functions #' @rdname dapply -#' @aliases dapply,SparkDataFrame,function,structType-method +#' @aliases dapply,SparkDataFrame,function,characterOrstructType-method #' @name dapply #' @seealso \link{dapplyCollect} #' @export @@ -1443,6 +1465,17 @@ dapplyInternal <- function(x, func, schema) { #' y <- cbind(y, y[1] + 1L) #' }, #' schema) +#' +#' # The schema also can be specified in a DDL-formatted string. +#' schema <- "a INT, d DOUBLE, c STRING, d INT" +#' df1 <- dapply( +#' df, +#' function(x) { +#' y <- x[x[1] > 1, ] +#' y <- cbind(y, y[1] + 1L) +#' }, +#' schema) +#' #' collect(df1) #' # the result #' # a b c d @@ -1451,7 +1484,7 @@ dapplyInternal <- function(x, func, schema) { #' } #' @note dapply since 2.0.0 setMethod("dapply", - signature(x = "SparkDataFrame", func = "function", schema = "structType"), + signature(x = "SparkDataFrame", func = "function", schema = "characterOrstructType"), function(x, func, schema) { dapplyInternal(x, func, schema) }) @@ -1521,6 +1554,7 @@ setMethod("dapplyCollect", #' @param schema the schema of the resulting SparkDataFrame after the function is applied. #' The schema must match to output of \code{func}. It has to be defined for each #' output column with preferred output column name and corresponding data type. +#' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @return A SparkDataFrame. #' @family SparkDataFrame functions #' @aliases gapply,SparkDataFrame-method @@ -1540,7 +1574,7 @@ setMethod("dapplyCollect", #' #' Here our output contains three columns, the key which is a combination of two #' columns with data types integer and string and the mean which is a double. -#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' schema <- structType(structField("a", "integer"), structField("c", "string"), #' structField("avg", "double")) #' result <- gapply( #' df, @@ -1549,6 +1583,15 @@ setMethod("dapplyCollect", #' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) #' }, schema) #' +#' The schema also can be specified in a DDL-formatted string. +#' schema <- "a INT, c STRING, avg DOUBLE" +#' result <- gapply( +#' df, +#' c("a", "c"), +#' function(key, x) { +#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) +#' }, schema) +#' #' We can also group the data and afterwards call gapply on GroupedData. #' For Example: #' gdf <- group_by(df, "a", "c") @@ -2645,6 +2688,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' Input SparkDataFrames can have different schemas (names and data types). #' #' Note: This does not remove duplicate rows across the two SparkDataFrames. +#' Also as standard in SQL, this function resolves columns by position (not by name). #' #' @param x A SparkDataFrame #' @param y A SparkDataFrame @@ -2653,7 +2697,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' @rdname union #' @name union #' @aliases union,SparkDataFrame,SparkDataFrame-method -#' @seealso \link{rbind} +#' @seealso \link{rbind} \link{unionByName} #' @export #' @examples #'\dontrun{ @@ -2684,6 +2728,40 @@ setMethod("unionAll", union(x, y) }) +#' Return a new SparkDataFrame containing the union of rows, matched by column names +#' +#' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame +#' and another SparkDataFrame. 
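Reviewer note: the `dapply`/`gapply` changes above accept a DDL-formatted string anywhere a `structType` schema was previously required. A sketch under the assumption of an active session; the data frame and column names are illustrative only:

```r
library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(a = 1:3, b = c(2.5, 3.5, 4.5)))

# structType and DDL string describe the same output schema
schemaS <- structType(structField("a", "integer"),
                      structField("b", "double"),
                      structField("c", "integer"))
schemaD <- "a INT, b DOUBLE, c INT"

addCol <- function(x) cbind(x, x$a + 1L)   # runs on a plain R data.frame per partition

collect(dapply(df, addCol, schemaS))
collect(dapply(df, addCol, schemaD))       # same result via the new string code path

# gapply takes the same string form for its result schema
collect(gapply(df, "a",
               function(key, x) data.frame(key, avg_b = mean(x$b)),
               "a INT, avg_b DOUBLE"))
```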
This is different from \code{union} function, and both +#' \code{UNION ALL} and \code{UNION DISTINCT} in SQL as column positions are not taken +#' into account. Input SparkDataFrames can have different data types in the schema. +#' +#' Note: This does not remove duplicate rows across the two SparkDataFrames. +#' This function resolves columns by name (not by position). +#' +#' @param x A SparkDataFrame +#' @param y A SparkDataFrame +#' @return A SparkDataFrame containing the result of the union. +#' @family SparkDataFrame functions +#' @rdname unionByName +#' @name unionByName +#' @aliases unionByName,SparkDataFrame,SparkDataFrame-method +#' @seealso \link{rbind} \link{union} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df1 <- select(createDataFrame(mtcars), "carb", "am", "gear") +#' df2 <- select(createDataFrame(mtcars), "am", "gear", "carb") +#' head(unionByName(df1, df2)) +#' } +#' @note unionByName since 2.3.0 +setMethod("unionByName", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + unioned <- callJMethod(x@sdf, "unionByName", y@sdf) + dataFrame(unioned) + }) + #' Union two or more SparkDataFrames #' #' Union two or more SparkDataFrames by row. As in R's \code{rbind}, this method @@ -2700,7 +2778,7 @@ setMethod("unionAll", #' @aliases rbind,SparkDataFrame-method #' @rdname rbind #' @name rbind -#' @seealso \link{union} +#' @seealso \link{union} \link{unionByName} #' @export #' @examples #'\dontrun{ @@ -2814,7 +2892,7 @@ setMethod("except", #' path <- "path/to/file.json" #' df <- read.json(path) #' write.df(df, "myfile", "parquet", "overwrite") -#' saveDF(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) +#' saveDF(df, parquetPath2, "parquet", mode = "append", mergeSchema = TRUE) #' } #' @note write.df since 1.4.0 setMethod("write.df", @@ -2900,7 +2978,7 @@ setMethod("saveAsTable", invisible(callJMethod(write, "saveAsTable", tableName)) }) -#' summary +#' describe #' #' Computes statistics for numeric and string columns. #' If no columns are given, this function computes statistics for all numerical or string columns. @@ -2911,7 +2989,7 @@ setMethod("saveAsTable", #' @return A SparkDataFrame. #' @family SparkDataFrame functions #' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method -#' @rdname summary +#' @rdname describe #' @name describe #' @export #' @examples @@ -2923,6 +3001,7 @@ setMethod("saveAsTable", #' describe(df, "col1") #' describe(df, "col1", "col2") #' } +#' @seealso See \link{summary} for expanded statistics and control over which statistics to compute. #' @note describe(SparkDataFrame, character) since 1.4.0 setMethod("describe", signature(x = "SparkDataFrame", col = "character"), @@ -2932,7 +3011,7 @@ setMethod("describe", dataFrame(sdf) }) -#' @rdname summary +#' @rdname describe #' @name describe #' @aliases describe,SparkDataFrame-method #' @note describe(SparkDataFrame) since 1.4.0 @@ -2943,15 +3022,50 @@ setMethod("describe", dataFrame(sdf) }) +#' summary +#' +#' Computes specified statistics for numeric and string columns. Available statistics are: +#' \itemize{ +#' \item count +#' \item mean +#' \item stddev +#' \item min +#' \item max +#' \item arbitrary approximate percentiles specified as a percentage (eg, "75%") +#' } +#' If no statistics are given, this function computes count, mean, stddev, min, +#' approximate quartiles (percentiles at 25%, 50%, and 75%), and max. 
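Reviewer note: `unionByName` resolves columns by name rather than by position, unlike `union`/`UNION ALL`. A small sketch mirroring the mtcars example in the roxygen block above, assuming an active session:

```r
library(SparkR)
sparkR.session()

df1 <- select(createDataFrame(mtcars), "carb", "am", "gear")
df2 <- select(createDataFrame(mtcars), "am", "gear", "carb")

# a positional union would silently pair carb with am here;
# unionByName lines the columns up by name instead
head(unionByName(df1, df2))

# duplicate rows are kept, as with union -- follow with distinct() if needed
nrow(unionByName(df1, df2)) == 2 * nrow(df1)
```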
+#' This function is meant for exploratory data analysis, as we make no guarantee about the +#' backward compatibility of the schema of the resulting Dataset. If you want to +#' programmatically compute summary statistics, use the \code{agg} function instead. +#' +#' #' @param object a SparkDataFrame to be summarized. +#' @param ... (optional) statistics to be computed for all columns. +#' @return A SparkDataFrame. +#' @family SparkDataFrame functions #' @rdname summary #' @name summary #' @aliases summary,SparkDataFrame-method +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' summary(df) +#' summary(df, "min", "25%", "75%", "max") +#' summary(select(df, "age", "height")) +#' } #' @note summary(SparkDataFrame) since 1.5.0 +#' @note The statistics provided by \code{summary} were change in 2.3.0 use \link{describe} for previous defaults. +#' @seealso \link{describe} setMethod("summary", signature(object = "SparkDataFrame"), function(object, ...) { - describe(object) + statisticsList <- list(...) + sdf <- callJMethod(object@sdf, "summary", statisticsList) + dataFrame(sdf) }) @@ -3097,8 +3211,8 @@ setMethod("fillna", #' @family SparkDataFrame functions #' @aliases as.data.frame,SparkDataFrame-method #' @rdname as.data.frame -#' @examples \dontrun{ -#' +#' @examples +#' \dontrun{ #' irisDF <- createDataFrame(iris) #' df <- as.data.frame(irisDF[irisDF$Species == "setosa", ]) #' } @@ -3175,7 +3289,8 @@ setMethod("with", #' @aliases str,SparkDataFrame-method #' @family SparkDataFrame functions #' @param object a SparkDataFrame -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' # Create a SparkDataFrame from the Iris dataset #' irisDF <- createDataFrame(iris) #' @@ -3667,8 +3782,8 @@ setMethod("checkpoint", #' mean(cube(df, "cyl", "gear", "am"), "mpg") #' #' # Following calls are equivalent -#' agg(cube(carsDF), mean(carsDF$mpg)) -#' agg(carsDF, mean(carsDF$mpg)) +#' agg(cube(df), mean(df$mpg)) +#' agg(df, mean(df$mpg)) #' } #' @note cube since 2.3.0 #' @seealso \link{agg}, \link{groupBy}, \link{rollup} @@ -3702,8 +3817,8 @@ setMethod("cube", #' mean(rollup(df, "cyl", "gear", "am"), "mpg") #' #' # Following calls are equivalent -#' agg(rollup(carsDF), mean(carsDF$mpg)) -#' agg(carsDF, mean(carsDF$mpg)) +#' agg(rollup(df), mean(df$mpg)) +#' agg(df, mean(df$mpg)) #' } #' @note rollup since 2.3.0 #' @seealso \link{agg}, \link{cube}, \link{groupBy} @@ -3715,3 +3830,86 @@ setMethod("rollup", sgd <- callJMethod(x@sdf, "rollup", jcol) groupedData(sgd) }) + +#' hint +#' +#' Specifies execution plan hint and return a new SparkDataFrame. +#' +#' @param x a SparkDataFrame. +#' @param name a name of the hint. +#' @param ... optional parameters for the hint. +#' @return A SparkDataFrame. +#' @family SparkDataFrame functions +#' @aliases hint,SparkDataFrame,character-method +#' @rdname hint +#' @name hint +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, hint(avg_mpg, "broadcast"), df$cyl == avg_mpg$cyl)) +#' } +#' @note hint since 2.2.0 +setMethod("hint", + signature(x = "SparkDataFrame", name = "character"), + function(x, name, ...) { + parameters <- list(...) 
+ stopifnot(all(sapply(parameters, is.character))) + jdf <- callJMethod(x@sdf, "hint", name, parameters) + dataFrame(jdf) + }) + +#' alias +#' +#' @aliases alias,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname alias +#' @name alias +#' @export +#' @examples +#' \dontrun{ +#' df <- alias(createDataFrame(mtcars), "mtcars") +#' avg_mpg <- alias(agg(groupBy(df, df$cyl), avg(df$mpg)), "avg_mpg") +#' +#' head(select(df, column("mtcars.mpg"))) +#' head(join(df, avg_mpg, column("mtcars.cyl") == column("avg_mpg.cyl"))) +#' } +#' @note alias(SparkDataFrame) since 2.3.0 +setMethod("alias", + signature(object = "SparkDataFrame"), + function(object, data) { + stopifnot(is.character(data)) + sdf <- callJMethod(object@sdf, "alias", data) + dataFrame(sdf) + }) + +#' broadcast +#' +#' Return a new SparkDataFrame marked as small enough for use in broadcast joins. +#' +#' Equivalent to \code{hint(x, "broadcast")}. +#' +#' @param x a SparkDataFrame. +#' @return a SparkDataFrame. +#' +#' @aliases broadcast,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname broadcast +#' @name broadcast +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, broadcast(avg_mpg), df$cyl == avg_mpg$cyl)) +#' } +#' @note broadcast since 2.3.0 +setMethod("broadcast", + signature(x = "SparkDataFrame"), + function(x) { + sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) + dataFrame(sdf) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 7ad3993e9ecb..15ca212acf87 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -227,7 +227,7 @@ setMethod("cacheRDD", #' #' Persist this RDD with the specified storage level. For details of the #' supported storage levels, refer to -#'\url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. +#'\url{http://spark.apache.org/docs/latest/rdd-programming-guide.html#rdd-persistence}. #' #' @param x The RDD to persist #' @param newLevel The new storage level to be assigned diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index f5c3a749fe0a..3b7f71bbbffb 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -334,7 +334,7 @@ setMethod("toDF", signature(x = "RDD"), #' #' Loads a JSON file, returning the result as a SparkDataFrame #' By default, (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} -#' ) is supported. For JSON (one record per file), set a named property \code{wholeFile} to +#' ) is supported. For JSON (one record per file), set a named property \code{multiLine} to #' \code{TRUE}. #' It goes through the entire dataset once to determine the schema. #' @@ -348,7 +348,7 @@ setMethod("toDF", signature(x = "RDD"), #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) -#' df <- read.json(path, wholeFile = TRUE) +#' df <- read.json(path, multiLine = TRUE) #' df <- jsonFile(path) #' } #' @name read.json @@ -584,7 +584,7 @@ tableToDF <- function(tableName) { #' #' @param path The path of files to load #' @param source The name of external data source -#' @param schema The data schema defined in structType +#' @param schema The data schema defined in structType or a DDL-formatted string. #' @param na.strings Default string value for NA when source is "csv" #' @param ... additional external data source specific named properties. 
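Reviewer note: `broadcast(x)` is shorthand for `hint(x, "broadcast")`; both only annotate the logical plan so the optimizer can choose a broadcast join. A sketch that follows the examples above, joining per-cylinder averages back onto mtcars (an active session is assumed):

```r
library(SparkR)
sparkR.session()

df <- createDataFrame(mtcars)
avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg")

# equivalent ways to mark the small side of the join
j1 <- join(df, broadcast(avg_mpg), df$cyl == avg_mpg$cyl)
j2 <- join(df, hint(avg_mpg, "broadcast"), df$cyl == avg_mpg$cyl)

head(j1)
explain(j1)   # the broadcast hint should be visible in the physical plan
```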
#' @return SparkDataFrame @@ -598,8 +598,10 @@ tableToDF <- function(tableName) { #' df1 <- read.df("path/to/file.json", source = "json") #' schema <- structType(structField("name", "string"), #' structField("info", "map")) -#' df2 <- read.df(mapTypeJsonPath, "json", schema, wholeFile = TRUE) +#' df2 <- read.df(mapTypeJsonPath, "json", schema, multiLine = TRUE) #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") +#' stringSchema <- "name STRING, info MAP" +#' df4 <- read.df(mapTypeJsonPath, "json", stringSchema, multiLine = TRUE) #' } #' @name read.df #' @method read.df default @@ -623,14 +625,19 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string if (source == "csv" && is.null(options[["nullValue"]])) { options[["nullValue"]] <- na.strings } + read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "format", source) if (!is.null(schema)) { - stopifnot(class(schema) == "structType") - sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, - source, schema$jobj, options) - } else { - sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, - source, options) + if (class(schema) == "structType") { + read <- callJMethod(read, "schema", schema$jobj) + } else if (is.character(schema)) { + read <- callJMethod(read, "schema", schema) + } else { + stop("schema should be structType or character.") + } } + read <- callJMethod(read, "options", options) + sdf <- handledCallJMethod(read, "load") dataFrame(sdf) } @@ -717,8 +724,8 @@ read.jdbc <- function(url, tableName, #' "spark.sql.sources.default" will be used. #' #' @param source The name of external data source -#' @param schema The data schema defined in structType, this is required for file-based streaming -#' data source +#' @param schema The data schema defined in structType or a DDL-formatted string, this is +#' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for #' file-based streaming data source #' @return SparkDataFrame @@ -733,6 +740,8 @@ read.jdbc <- function(url, tableName, #' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") #' #' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' stringSchema <- "name STRING, info MAP" +#' df1 <- read.stream("json", path = jsonDir, schema = stringSchema, maxFilesPerTrigger = 1) #' } #' @name read.stream #' @note read.stream since 2.2.0 @@ -750,10 +759,15 @@ read.stream <- function(source = NULL, schema = NULL, ...) 
{ read <- callJMethod(sparkSession, "readStream") read <- callJMethod(read, "format", source) if (!is.null(schema)) { - stopifnot(class(schema) == "structType") - read <- callJMethod(read, "schema", schema$jobj) + if (class(schema) == "structType") { + read <- callJMethod(read, "schema", schema$jobj) + } else if (is.character(schema)) { + read <- callJMethod(read, "schema", schema) + } else { + stop("schema should be structType or character.") + } } read <- callJMethod(read, "options", options) sdf <- handledCallJMethod(read, "load") - dataFrame(callJMethod(sdf, "toDF")) + dataFrame(sdf) } diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index 4ac83c29c6f7..81beac9ea992 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -203,7 +203,8 @@ setMethod("rangeBetween", #' @aliases over,Column,WindowSpec-method #' @family colum_func #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' #' # Partition by am (transmission) and order by hp (horsepower) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 147ee4b6887b..a5c2ea81f249 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -130,19 +130,20 @@ createMethods <- function() { createMethods() -#' alias -#' -#' Set a new name for a column -#' -#' @param object Column to rename -#' @param data new name to use -#' #' @rdname alias #' @name alias #' @aliases alias,Column-method #' @family colum_func #' @export -#' @note alias since 1.4.0 +#' @examples +#' \dontrun{ +#' df <- createDataFrame(iris) +#' +#' head(select( +#' df, alias(df$Sepal_Length, "slength"), alias(df$Petal_Length, "plength") +#' )) +#' } +#' @note alias(Column) since 1.4.0 setMethod("alias", signature(object = "Column"), function(object, data) { @@ -244,7 +245,8 @@ setMethod("between", signature(x = "Column"), #' @family colum_func #' @aliases cast,Column-method #' -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' cast(df$age, "string") #' } #' @note cast since 1.4.0 diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 50856e3d9856..8349b57a30a9 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -258,7 +258,7 @@ includePackage <- function(sc, pkg) { #' #' # Large Matrix object that we want to broadcast #' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) -#' randomMatBr <- broadcast(sc, randomMat) +#' randomMatBr <- broadcastRDD(sc, randomMat) #' #' # Use the broadcast variable inside the function #' useBroadcast <- function(x) { @@ -266,7 +266,7 @@ includePackage <- function(sc, pkg) { #' } #' sumRDD <- lapply(rdd, useBroadcast) #'} -broadcast <- function(sc, object) { +broadcastRDD <- function(sc, object) { objName <- as.character(substitute(object)) serializedObj <- serialize(object, connection = NULL) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f9687d680e7a..9f286263c216 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -18,23 +18,233 @@ #' @include generics.R column.R NULL -#' lit +#' Aggregate functions for Column operations #' -#' A new \linkS4class{Column} is created to represent the literal value. -#' If the parameter is a \linkS4class{Column}, it is returned unchanged. +#' Aggregate functions defined for \code{Column}. #' -#' @param x a literal value or a Column. -#' @family normal_funcs -#' @rdname lit -#' @name lit +#' @param x Column to compute on. +#' @param y,na.rm,use currently not used. +#' @param ... additional argument(s). For example, it could be used to pass additional Columns. 
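Reviewer note: the `read.df`/`read.stream` rework above routes everything through the DataFrameReader and accepts either a `structType` or a DDL string as the schema. A sketch using a throwaway JSON file; the path and column names are illustrative, not from the patch:

```r
library(SparkR)
sparkR.session()

# a tiny newline-delimited JSON file to read back
jsonPath <- tempfile(fileext = ".json")
writeLines(c('{"name":"Andy","age":30}', '{"name":"Justin","age":19}'), jsonPath)

# structType and DDL string are now interchangeable
schemaS <- structType(structField("name", "string"), structField("age", "double"))
df1 <- read.df(jsonPath, source = "json", schema = schemaS)
df2 <- read.df(jsonPath, source = "json", schema = "name STRING, age DOUBLE")
head(df2)

# the same string form works for the streaming reader, e.g.
# sdf <- read.stream("json", path = someJsonDir, schema = "name STRING, age DOUBLE")
```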
+#' @name column_aggregate_functions +#' @rdname column_aggregate_functions +#' @family aggregate functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} +NULL + +#' Date time functions for Column operations +#' +#' Date time functions defined for \code{Column}. +#' +#' @param x Column to compute on. In \code{window}, it must be a time Column of \code{TimestampType}. +#' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse +#' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string +#' to use to specify the truncation method. For example, "year", "yyyy", "yy" for +#' truncate by year, or "month", "mon", "mm" for truncate by month. +#' @param ... additional argument(s). +#' @name column_datetime_functions +#' @rdname column_datetime_functions +#' @family data time functions +#' @examples +#' \dontrun{ +#' dts <- c("2005-01-02 18:47:22", +#' "2005-12-24 16:30:58", +#' "2005-10-28 07:30:05", +#' "2005-12-28 07:01:05", +#' "2006-01-24 00:01:10") +#' y <- c(2.0, 2.2, 3.4, 2.5, 1.8) +#' df <- createDataFrame(data.frame(time = as.POSIXct(dts), y = y))} +NULL + +#' Date time arithmetic functions for Column operations +#' +#' Date time arithmetic functions defined for \code{Column}. +#' +#' @param y Column to compute on. +#' @param x For class \code{Column}, it is the column used to perform arithmetic operations +#' with column \code{y}. For class \code{numeric}, it is the number of months or +#' days to be added to or subtracted from \code{y}. For class \code{character}, it is +#' \itemize{ +#' \item \code{date_format}: date format specification. +#' \item \code{from_utc_timestamp}, \code{to_utc_timestamp}: time zone to use. +#' \item \code{next_day}: day of the week string. +#' } +#' +#' @name column_datetime_diff_functions +#' @rdname column_datetime_diff_functions +#' @family data time functions +#' @examples +#' \dontrun{ +#' dts <- c("2005-01-02 18:47:22", +#' "2005-12-24 16:30:58", +#' "2005-10-28 07:30:05", +#' "2005-12-28 07:01:05", +#' "2006-01-24 00:01:10") +#' y <- c(2.0, 2.2, 3.4, 2.5, 1.8) +#' df <- createDataFrame(data.frame(time = as.POSIXct(dts), y = y))} +NULL + +#' Math functions for Column operations +#' +#' Math functions defined for \code{Column}. +#' +#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and \code{shiftRightUnsigned}, +#' this is the number of bits to shift. +#' @param y Column to compute on. +#' @param ... additional argument(s). +#' @name column_math_functions +#' @rdname column_math_functions +#' @family math functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' tmp <- mutate(df, v1 = log(df$mpg), v2 = cbrt(df$disp), +#' v3 = bround(df$wt, 1), v4 = bin(df$cyl), +#' v5 = hex(df$wt), v6 = toDegrees(df$gear), +#' v7 = atan2(df$cyl, df$am), v8 = hypot(df$cyl, df$am), +#' v9 = pmod(df$hp, df$cyl), v10 = shiftLeft(df$disp, 1), +#' v11 = conv(df$hp, 10, 16), v12 = sign(df$vs - 0.5), +#' v13 = sqrt(df$disp), v14 = ceil(df$wt)) +#' head(tmp)} +NULL + +#' String functions for Column operations +#' +#' String functions defined for \code{Column}. +#' +#' @param x Column to compute on except in the following methods: +#' \itemize{ +#' \item \code{instr}: \code{character}, the substring to check. See 'Details'. +#' \item \code{format_number}: \code{numeric}, the number of decimal place to +#' format to. 
See 'Details'. +#' } +#' @param y Column to compute on. +#' @param ... additional Columns. +#' @name column_string_functions +#' @rdname column_string_functions +#' @family string functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(as.data.frame(Titanic, stringsAsFactors = FALSE))} +NULL + +#' Non-aggregate functions for Column operations +#' +#' Non-aggregate functions defined for \code{Column}. +#' +#' @param x Column to compute on. In \code{lit}, it is a literal value or a Column. +#' In \code{expr}, it contains an expression character object to be parsed. +#' @param y Column to compute on. +#' @param ... additional Columns. +#' @name column_nonaggregate_functions +#' @rdname column_nonaggregate_functions +#' @seealso coalesce,SparkDataFrame-method +#' @family non-aggregate functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} +NULL + +#' Miscellaneous functions for Column operations +#' +#' Miscellaneous functions defined for \code{Column}. +#' +#' @param x Column to compute on. In \code{sha2}, it is one of 224, 256, 384, or 512. +#' @param y Column to compute on. +#' @param ... additional Columns. +#' @name column_misc_functions +#' @rdname column_misc_functions +#' @family misc functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)[, 1:2]) +#' tmp <- mutate(df, v1 = crc32(df$model), v2 = hash(df$model), +#' v3 = hash(df$model, df$mpg), v4 = md5(df$model), +#' v5 = sha1(df$model), v6 = sha2(df$model, 256)) +#' head(tmp)} +NULL + +#' Collection functions for Column operations +#' +#' Collection functions defined for \code{Column}. +#' +#' @param x Column to compute on. Note the difference in the following methods: +#' \itemize{ +#' \item \code{to_json}: it is the column containing the struct, array of the structs, +#' the map or array of maps. +#' \item \code{from_json}: it is the column containing the JSON string. +#' } +#' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains +#' additional named properties to control how it is converted, accepts the same +#' options as the JSON data source. +#' @name column_collection_functions +#' @rdname column_collection_functions +#' @family collection functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) +#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) +#' head(tmp2) +#' head(select(tmp, posexplode(tmp$v1))) +#' head(select(tmp, sort_array(tmp$v1))) +#' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) +#' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) +#' head(select(tmp3, map_keys(tmp3$v3))) +#' head(select(tmp3, map_values(tmp3$v3)))} +NULL + +#' Window functions for Column operations +#' +#' Window functions defined for \code{Column}. +#' +#' @param x In \code{lag} and \code{lead}, it is the column as a character string or a Column +#' to compute on. In \code{ntile}, it is the number of ntile groups. +#' @param offset In \code{lag}, the number of rows back from the current row from which to obtain +#' a value. In \code{lead}, the number of rows after the current row from which to +#' obtain a value. If not specified, the default is 1. 
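Reviewer note: `map_keys` and `map_values` are newly exported in this patch (see the NAMESPACE chunk earlier) and are grouped with the other map/array helpers under `column_collection_functions`. A sketch following the shared mtcars example header, assuming an active session:

```r
library(SparkR)
sparkR.session()

df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))

# build a map column keyed by model name, then pull keys and values back out
tmp <- mutate(df, m = create_map(df$model, df$cyl))
head(select(tmp, map_keys(tmp$m), map_values(tmp$m)))
```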
+#' @param defaultValue (optional) default to use when the offset row does not exist. +#' @param ... additional argument(s). +#' @name column_window_functions +#' @rdname column_window_functions +#' @family window functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' tmp <- mutate(df, dist = over(cume_dist(), ws), dense_rank = over(dense_rank(), ws), +#' lag = over(lag(df$mpg), ws), lead = over(lead(df$mpg, 1), ws), +#' percent_rank = over(percent_rank(), ws), +#' rank = over(rank(), ws), row_number = over(row_number(), ws)) +#' # Get ntile group id (1-4) for hp +#' tmp <- mutate(tmp, ntile = over(ntile(4), ws)) +#' head(tmp)} +NULL + +#' @details +#' \code{lit}: A new Column is created to represent the literal value. +#' If the parameter is a Column, it is returned unchanged. +#' +#' @rdname column_nonaggregate_functions #' @export -#' @aliases lit,ANY-method +#' @aliases lit lit,ANY-method #' @examples +#' #' \dontrun{ -#' lit(df$name) -#' select(df, lit("x")) -#' select(df, lit("2015-01-01")) -#'} +#' tmp <- mutate(df, v1 = lit(df$mpg), v2 = lit("x"), v3 = lit("2015-01-01"), +#' v4 = negate(df$mpg), v5 = expr('length(model)'), +#' v6 = greatest(df$vs, df$am), v7 = least(df$vs, df$am), +#' v8 = column("mpg")) +#' head(tmp)} #' @note lit since 1.5.0 setMethod("lit", signature("ANY"), function(x) { @@ -44,18 +254,12 @@ setMethod("lit", signature("ANY"), column(jc) }) -#' abs +#' @details +#' \code{abs}: Computes the absolute value. #' -#' Computes the absolute value. -#' -#' @param x Column to compute on. -#' -#' @rdname abs -#' @name abs -#' @family normal_funcs +#' @rdname column_math_functions #' @export -#' @examples \dontrun{abs(df$c)} -#' @aliases abs,Column-method +#' @aliases abs abs,Column-method #' @note abs since 1.5.0 setMethod("abs", signature(x = "Column"), @@ -64,19 +268,13 @@ setMethod("abs", column(jc) }) -#' acos -#' -#' Computes the cosine inverse of the given value; the returned angle is in the range -#' 0.0 through pi. +#' @details +#' \code{acos}: Computes the cosine inverse of the given value; the returned angle is in +#' the range 0.0 through pi. #' -#' @param x Column to compute on. -#' -#' @rdname acos -#' @name acos -#' @family math_funcs +#' @rdname column_math_functions #' @export -#' @examples \dontrun{acos(df$c)} -#' @aliases acos,Column-method +#' @aliases acos acos,Column-method #' @note acos since 1.5.0 setMethod("acos", signature(x = "Column"), @@ -85,17 +283,20 @@ setMethod("acos", column(jc) }) -#' Returns the approximate number of distinct items in a group +#' @details +#' \code{approxCountDistinct}: Returns the approximate number of distinct items in a group. #' -#' Returns the approximate number of distinct items in a group. This is a column -#' aggregate function. -#' -#' @rdname approxCountDistinct -#' @name approxCountDistinct -#' @return the approximate number of distinct items in a group. 
+#' @rdname column_aggregate_functions #' @export -#' @aliases approxCountDistinct,Column-method -#' @examples \dontrun{approxCountDistinct(df$c)} +#' @aliases approxCountDistinct approxCountDistinct,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, approxCountDistinct(df$gear))) +#' head(select(df, approxCountDistinct(df$gear, 0.02))) +#' head(select(df, countDistinct(df$gear, df$cyl))) +#' head(select(df, n_distinct(df$gear))) +#' head(distinct(select(df, "gear")))} #' @note approxCountDistinct(Column) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -104,19 +305,17 @@ setMethod("approxCountDistinct", column(jc) }) -#' ascii +#' @details +#' \code{ascii}: Computes the numeric value of the first character of the string column, +#' and returns the result as an int column. #' -#' Computes the numeric value of the first character of the string column, and returns the -#' result as a int column. -#' -#' @param x Column to compute on. -#' -#' @rdname ascii -#' @name ascii -#' @family string_funcs +#' @rdname column_string_functions #' @export -#' @aliases ascii,Column-method -#' @examples \dontrun{\dontrun{ascii(df$c)}} +#' @aliases ascii ascii,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, ascii(df$Class), ascii(df$Sex)))} #' @note ascii since 1.5.0 setMethod("ascii", signature(x = "Column"), @@ -125,19 +324,13 @@ setMethod("ascii", column(jc) }) -#' asin -#' -#' Computes the sine inverse of the given value; the returned angle is in the range -#' -pi/2 through pi/2. +#' @details +#' \code{asin}: Computes the sine inverse of the given value; the returned angle is in +#' the range -pi/2 through pi/2. #' -#' @param x Column to compute on. -#' -#' @rdname asin -#' @name asin -#' @family math_funcs +#' @rdname column_math_functions #' @export -#' @aliases asin,Column-method -#' @examples \dontrun{asin(df$c)} +#' @aliases asin asin,Column-method #' @note asin since 1.5.0 setMethod("asin", signature(x = "Column"), @@ -146,18 +339,13 @@ setMethod("asin", column(jc) }) -#' atan -#' -#' Computes the tangent inverse of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{atan}: Computes the tangent inverse of the given value; the returned angle is in the range +#' -pi/2 through pi/2. #' -#' @rdname atan -#' @name atan -#' @family math_funcs +#' @rdname column_math_functions #' @export -#' @aliases atan,Column-method -#' @examples \dontrun{atan(df$c)} +#' @aliases atan atan,Column-method #' @note atan since 1.5.0 setMethod("atan", signature(x = "Column"), @@ -172,7 +360,7 @@ setMethod("atan", #' #' @rdname avg #' @name avg -#' @family agg_funcs +#' @family aggregate functions #' @export #' @aliases avg,Column-method #' @examples \dontrun{avg(df$c)} @@ -184,19 +372,22 @@ setMethod("avg", column(jc) }) -#' base64 -#' -#' Computes the BASE64 encoding of a binary column and returns it as a string column. -#' This is the reverse of unbase64. +#' @details +#' \code{base64}: Computes the BASE64 encoding of a binary column and returns it as +#' a string column. This is the reverse of unbase64. #' -#' @param x Column to compute on. 
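Reviewer note: `avg` keeps its own rdname but now sits in the "aggregate functions" family; like the other aggregates it works both over a whole frame and per group. A short sketch (the column choices are mine), assuming an active session:

```r
library(SparkR)
sparkR.session()

df <- createDataFrame(mtcars)

# whole-frame and per-group aggregation with avg/count
head(select(df, avg(df$mpg)))
head(agg(groupBy(df, "cyl"), avg(df$mpg), count(df$mpg)))
```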
-#' -#' @rdname base64 -#' @name base64 -#' @family string_funcs +#' @rdname column_string_functions #' @export -#' @aliases base64,Column-method -#' @examples \dontrun{base64(df$c)} +#' @aliases base64 base64,Column-method +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = encode(df$Class, "UTF-8")) +#' str(tmp) +#' tmp2 <- mutate(tmp, s2 = base64(tmp$s1), s3 = decode(tmp$s1, "UTF-8"), +#' s4 = soundex(tmp$Sex)) +#' head(tmp2) +#' head(select(tmp2, unbase64(tmp2$s2)))} #' @note base64 since 1.5.0 setMethod("base64", signature(x = "Column"), @@ -205,19 +396,13 @@ setMethod("base64", column(jc) }) -#' bin -#' -#' An expression that returns the string representation of the binary value of the given long -#' column. For example, bin("12") returns "1100". -#' -#' @param x Column to compute on. +#' @details +#' \code{bin}: Returns the string representation of the binary value +#' of the given long column. For example, bin("12") returns "1100". #' -#' @rdname bin -#' @name bin -#' @family math_funcs +#' @rdname column_math_functions #' @export -#' @aliases bin,Column-method -#' @examples \dontrun{bin(df$c)} +#' @aliases bin bin,Column-method #' @note bin since 1.5.0 setMethod("bin", signature(x = "Column"), @@ -226,18 +411,16 @@ setMethod("bin", column(jc) }) -#' bitwiseNOT +#' @details +#' \code{bitwiseNOT}: Computes bitwise NOT. #' -#' Computes bitwise NOT. -#' -#' @param x Column to compute on. -#' -#' @rdname bitwiseNOT -#' @name bitwiseNOT -#' @family normal_funcs +#' @rdname column_nonaggregate_functions #' @export -#' @aliases bitwiseNOT,Column-method -#' @examples \dontrun{bitwiseNOT(df$c)} +#' @aliases bitwiseNOT bitwiseNOT,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, bitwiseNOT(cast(df$vs, "int"))))} #' @note bitwiseNOT since 1.5.0 setMethod("bitwiseNOT", signature(x = "Column"), @@ -246,18 +429,12 @@ setMethod("bitwiseNOT", column(jc) }) -#' cbrt -#' -#' Computes the cube-root of the given value. +#' @details +#' \code{cbrt}: Computes the cube-root of the given value. #' -#' @param x Column to compute on. -#' -#' @rdname cbrt -#' @name cbrt -#' @family math_funcs +#' @rdname column_math_functions #' @export -#' @aliases cbrt,Column-method -#' @examples \dontrun{cbrt(df$c)} +#' @aliases cbrt cbrt,Column-method #' @note cbrt since 1.4.0 setMethod("cbrt", signature(x = "Column"), @@ -266,18 +443,12 @@ setMethod("cbrt", column(jc) }) -#' Computes the ceiling of the given value +#' @details +#' \code{ceil}: Computes the ceiling of the given value. #' -#' Computes the ceiling of the given value. -#' -#' @param x Column to compute on. -#' -#' @rdname ceil -#' @name ceil -#' @family math_funcs +#' @rdname column_math_functions #' @export -#' @aliases ceil,Column-method -#' @examples \dontrun{ceil(df$c)} +#' @aliases ceil ceil,Column-method #' @note ceil since 1.5.0 setMethod("ceil", signature(x = "Column"), @@ -286,16 +457,25 @@ setMethod("ceil", column(jc) }) -#' Returns the first column that is not NA +#' @details +#' \code{ceiling}: Alias for \code{ceil}. #' -#' Returns the first column that is not NA, or NA if all inputs are. +#' @rdname column_math_functions +#' @aliases ceiling ceiling,Column-method +#' @export +#' @note ceiling since 1.5.0 +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + +#' @details +#' \code{coalesce}: Returns the first column that is not NA, or NA if all inputs are. 
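Reviewer note: `ceiling` is added here purely as an alias that delegates to `ceil`, so base-R-style spellings keep working on Columns. A minimal sketch, assuming an active session and the mtcars frame used throughout these docs:

```r
library(SparkR)
sparkR.session()

df <- createDataFrame(mtcars)

# both spellings produce the same Column expression
head(select(df, ceil(df$wt), ceiling(df$wt), floor(df$wt)))
```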
#' -#' @rdname coalesce -#' @name coalesce -#' @family normal_funcs +#' @rdname column_nonaggregate_functions #' @export #' @aliases coalesce,Column-method -#' @examples \dontrun{coalesce(df$c, df$d, df$e)} #' @note coalesce(Column) since 2.1.1 setMethod("coalesce", signature(x = "Column"), @@ -324,7 +504,7 @@ col <- function(x) { #' #' @rdname column #' @name column -#' @family normal_funcs +#' @family non-aggregate functions #' @export #' @aliases column,character-method #' @examples \dontrun{column("name")} @@ -334,6 +514,7 @@ setMethod("column", function(x) { col(x) }) + #' corr #' #' Computes the Pearson Correlation Coefficient for two Columns. @@ -342,10 +523,13 @@ setMethod("column", #' #' @rdname corr #' @name corr -#' @family math_funcs +#' @family aggregate functions #' @export #' @aliases corr,Column-method -#' @examples \dontrun{corr(df$c, df$d)} +#' @examples +#' \dontrun{ +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' head(select(df, corr(df$mpg, df$hp)))} #' @note corr since 1.6.0 setMethod("corr", signature(x = "Column"), function(x, col2) { @@ -356,20 +540,22 @@ setMethod("corr", signature(x = "Column"), #' cov #' -#' Compute the sample covariance between two expressions. +#' Compute the covariance between two expressions. +#' +#' @details +#' \code{cov}: Compute the sample covariance between two expressions. #' #' @rdname cov #' @name cov -#' @family math_funcs +#' @family aggregate functions #' @export #' @aliases cov,characterOrColumn-method #' @examples #' \dontrun{ -#' cov(df$c, df$d) -#' cov("c", "d") -#' covar_samp(df$c, df$d) -#' covar_samp("c", "d") -#' } +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' head(select(df, cov(df$mpg, df$hp), cov("mpg", "hp"), +#' covar_samp(df$mpg, df$hp), covar_samp("mpg", "hp"), +#' covar_pop(df$mpg, df$hp), covar_pop("mpg", "hp")))} #' @note cov since 1.6.0 setMethod("cov", signature(x = "characterOrColumn"), function(x, col2) { @@ -377,6 +563,9 @@ setMethod("cov", signature(x = "characterOrColumn"), covar_samp(x, col2) }) +#' @details +#' \code{covar_sample}: Alias for \code{cov}. +#' #' @rdname cov #' #' @param col1 the first Column. @@ -395,23 +584,13 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO column(jc) }) -#' covar_pop -#' -#' Compute the population covariance between two expressions. +#' @details +#' \code{covar_pop}: Computes the population covariance between two expressions. #' -#' @param col1 First column to compute cov_pop. -#' @param col2 Second column to compute cov_pop. -#' -#' @rdname covar_pop +#' @rdname cov #' @name covar_pop -#' @family math_funcs #' @export #' @aliases covar_pop,characterOrColumn,characterOrColumn-method -#' @examples -#' \dontrun{ -#' covar_pop(df$c, df$d) -#' covar_pop("c", "d") -#' } #' @note covar_pop since 2.0.0 setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), function(col1, col2) { @@ -424,18 +603,12 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr column(jc) }) -#' cos -#' -#' Computes the cosine of the given value. +#' @details +#' \code{cos}: Computes the cosine of the given value. Units in radians. #' -#' @param x Column to compute on. 
-#' -#' @rdname cos -#' @name cos -#' @family math_funcs -#' @aliases cos,Column-method +#' @rdname column_math_functions +#' @aliases cos cos,Column-method #' @export -#' @examples \dontrun{cos(df$c)} #' @note cos since 1.5.0 setMethod("cos", signature(x = "Column"), @@ -444,18 +617,12 @@ setMethod("cos", column(jc) }) -#' cosh -#' -#' Computes the hyperbolic cosine of the given value. +#' @details +#' \code{cosh}: Computes the hyperbolic cosine of the given value. #' -#' @param x Column to compute on. -#' -#' @rdname cosh -#' @name cosh -#' @family math_funcs -#' @aliases cosh,Column-method +#' @rdname column_math_functions +#' @aliases cosh cosh,Column-method #' @export -#' @examples \dontrun{cosh(df$c)} #' @note cosh since 1.5.0 setMethod("cosh", signature(x = "Column"), @@ -471,7 +638,7 @@ setMethod("cosh", #' #' @rdname count #' @name count -#' @family agg_funcs +#' @family aggregate functions #' @aliases count,Column-method #' @export #' @examples \dontrun{count(df$c)} @@ -483,19 +650,13 @@ setMethod("count", column(jc) }) -#' crc32 -#' -#' Calculates the cyclic redundancy check value (CRC32) of a binary column and -#' returns the value as a bigint. -#' -#' @param x Column to compute on. +#' @details +#' \code{crc32}: Calculates the cyclic redundancy check value (CRC32) of a binary column +#' and returns the value as a bigint. #' -#' @rdname crc32 -#' @name crc32 -#' @family misc_funcs -#' @aliases crc32,Column-method +#' @rdname column_misc_functions +#' @aliases crc32 crc32,Column-method #' @export -#' @examples \dontrun{crc32(df$c)} #' @note crc32 since 1.5.0 setMethod("crc32", signature(x = "Column"), @@ -504,19 +665,13 @@ setMethod("crc32", column(jc) }) -#' hash -#' -#' Calculates the hash code of given columns, and returns the result as a int column. -#' -#' @param x Column to compute on. -#' @param ... additional Column(s) to be included. +#' @details +#' \code{hash}: Calculates the hash code of given columns, and returns the result +#' as an int column. #' -#' @rdname hash -#' @name hash -#' @family misc_funcs -#' @aliases hash,Column-method +#' @rdname column_misc_functions +#' @aliases hash hash,Column-method #' @export -#' @examples \dontrun{hash(df$c)} #' @note hash since 2.0.0 setMethod("hash", signature(x = "Column"), @@ -529,18 +684,20 @@ setMethod("hash", column(jc) }) -#' dayofmonth +#' @details +#' \code{dayofmonth}: Extracts the day of the month as an integer from a +#' given date/timestamp/string. #' -#' Extracts the day of the month as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. -#' -#' @rdname dayofmonth -#' @name dayofmonth -#' @family datetime_funcs -#' @aliases dayofmonth,Column-method +#' @rdname column_datetime_functions +#' @aliases dayofmonth dayofmonth,Column-method #' @export -#' @examples \dontrun{dayofmonth(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, df$time, year(df$time), quarter(df$time), month(df$time), +#' dayofmonth(df$time), dayofyear(df$time), weekofyear(df$time))) +#' head(agg(groupBy(df, year(df$time)), count(df$y), avg(df$y))) +#' head(agg(groupBy(df, month(df$time)), avg(df$y)))} #' @note dayofmonth since 1.5.0 setMethod("dayofmonth", signature(x = "Column"), @@ -549,18 +706,13 @@ setMethod("dayofmonth", column(jc) }) -#' dayofyear -#' -#' Extracts the day of the year as an integer from a given date/timestamp/string. +#' @details +#' \code{dayofyear}: Extracts the day of the year as an integer from a +#' given date/timestamp/string. #' -#' @param x Column to compute on. 
-#' -#' @rdname dayofyear -#' @name dayofyear -#' @family datetime_funcs -#' @aliases dayofyear,Column-method +#' @rdname column_datetime_functions +#' @aliases dayofyear dayofyear,Column-method #' @export -#' @examples \dontrun{dayofyear(df$c)} #' @note dayofyear since 1.5.0 setMethod("dayofyear", signature(x = "Column"), @@ -569,20 +721,16 @@ setMethod("dayofyear", column(jc) }) -#' decode +#' @details +#' \code{decode}: Computes the first argument into a string from a binary using the provided +#' character set. #' -#' Computes the first argument into a string from a binary using the provided character set -#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). -#' -#' @param x Column to compute on. -#' @param charset Character set to use +#' @param charset character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", +#' "UTF-16LE", "UTF-16"). #' -#' @rdname decode -#' @name decode -#' @family string_funcs -#' @aliases decode,Column,character-method +#' @rdname column_string_functions +#' @aliases decode decode,Column,character-method #' @export -#' @examples \dontrun{decode(df$c, "UTF-8")} #' @note decode since 1.6.0 setMethod("decode", signature(x = "Column", charset = "character"), @@ -591,20 +739,13 @@ setMethod("decode", column(jc) }) -#' encode +#' @details +#' \code{encode}: Computes the first argument into a binary from a string using the provided +#' character set. #' -#' Computes the first argument into a binary from a string using the provided character set -#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). -#' -#' @param x Column to compute on. -#' @param charset Character set to use -#' -#' @rdname encode -#' @name encode -#' @family string_funcs -#' @aliases encode,Column,character-method +#' @rdname column_string_functions +#' @aliases encode encode,Column,character-method #' @export -#' @examples \dontrun{encode(df$c, "UTF-8")} #' @note encode since 1.6.0 setMethod("encode", signature(x = "Column", charset = "character"), @@ -613,18 +754,12 @@ setMethod("encode", column(jc) }) -#' exp -#' -#' Computes the exponential of the given value. +#' @details +#' \code{exp}: Computes the exponential of the given value. #' -#' @param x Column to compute on. -#' -#' @rdname exp -#' @name exp -#' @family math_funcs -#' @aliases exp,Column-method +#' @rdname column_math_functions +#' @aliases exp exp,Column-method #' @export -#' @examples \dontrun{exp(df$c)} #' @note exp since 1.5.0 setMethod("exp", signature(x = "Column"), @@ -633,18 +768,12 @@ setMethod("exp", column(jc) }) -#' expm1 +#' @details +#' \code{expm1}: Computes the exponential of the given value minus one. #' -#' Computes the exponential of the given value minus one. -#' -#' @param x Column to compute on. -#' -#' @rdname expm1 -#' @name expm1 -#' @aliases expm1,Column-method -#' @family math_funcs +#' @rdname column_math_functions +#' @aliases expm1 expm1,Column-method #' @export -#' @examples \dontrun{expm1(df$c)} #' @note expm1 since 1.5.0 setMethod("expm1", signature(x = "Column"), @@ -653,18 +782,12 @@ setMethod("expm1", column(jc) }) -#' factorial -#' -#' Computes the factorial of the given value. +#' @details +#' \code{factorial}: Computes the factorial of the given value. #' -#' @param x Column to compute on. 
-#' -#' @rdname factorial -#' @name factorial -#' @aliases factorial,Column-method -#' @family math_funcs +#' @rdname column_math_functions +#' @aliases factorial factorial,Column-method #' @export -#' @examples \dontrun{factorial(df$c)} #' @note factorial since 1.5.0 setMethod("factorial", signature(x = "Column"), @@ -686,7 +809,7 @@ setMethod("factorial", #' @rdname first #' @name first #' @aliases first,characterOrColumn-method -#' @family agg_funcs +#' @family aggregate functions #' @export #' @examples #' \dontrun{ @@ -706,18 +829,12 @@ setMethod("first", column(jc) }) -#' floor -#' -#' Computes the floor of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{floor}: Computes the floor of the given value. #' -#' @rdname floor -#' @name floor -#' @aliases floor,Column-method -#' @family math_funcs +#' @rdname column_math_functions +#' @aliases floor floor,Column-method #' @export -#' @examples \dontrun{floor(df$c)} #' @note floor since 1.5.0 setMethod("floor", signature(x = "Column"), @@ -726,18 +843,12 @@ setMethod("floor", column(jc) }) -#' hex -#' -#' Computes hex value of the given column. -#' -#' @param x Column to compute on. +#' @details +#' \code{hex}: Computes hex value of the given column. #' -#' @rdname hex -#' @name hex -#' @family math_funcs -#' @aliases hex,Column-method +#' @rdname column_math_functions +#' @aliases hex hex,Column-method #' @export -#' @examples \dontrun{hex(df$c)} #' @note hex since 1.5.0 setMethod("hex", signature(x = "Column"), @@ -746,18 +857,19 @@ setMethod("hex", column(jc) }) -#' hour +#' @details +#' \code{hour}: Extracts the hour as an integer from a given date/timestamp/string. #' -#' Extracts the hours as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. -#' -#' @rdname hour -#' @name hour -#' @aliases hour,Column-method -#' @family datetime_funcs +#' @rdname column_datetime_functions +#' @aliases hour hour,Column-method #' @export -#' @examples \dontrun{hour(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, hour(df$time), minute(df$time), second(df$time))) +#' head(agg(groupBy(df, dayofmonth(df$time)), avg(df$y))) +#' head(agg(groupBy(df, hour(df$time)), avg(df$y))) +#' head(agg(groupBy(df, minute(df$time)), avg(df$y)))} #' @note hour since 1.5.0 setMethod("hour", signature(x = "Column"), @@ -766,21 +878,23 @@ setMethod("hour", column(jc) }) -#' initcap -#' -#' Returns a new string column by converting the first letter of each word to uppercase. -#' Words are delimited by whitespace. +#' @details +#' \code{initcap}: Returns a new string column by converting the first letter of +#' each word to uppercase. Words are delimited by whitespace. For example, "hello world" +#' will become "Hello World". #' -#' For example, "hello world" will become "Hello World". -#' -#' @param x Column to compute on. 
-#' -#' @rdname initcap -#' @name initcap -#' @family string_funcs -#' @aliases initcap,Column-method +#' @rdname column_string_functions +#' @aliases initcap initcap,Column-method #' @export -#' @examples \dontrun{initcap(df$c)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, sex_lower = lower(df$Sex), age_upper = upper(df$age), +#' sex_age = concat_ws(" ", lower(df$sex), lower(df$age))) +#' head(tmp) +#' tmp2 <- mutate(tmp, s1 = initcap(tmp$sex_lower), s2 = initcap(tmp$sex_age), +#' s3 = reverse(df$Sex)) +#' head(tmp2)} #' @note initcap since 1.5.0 setMethod("initcap", signature(x = "Column"), @@ -789,32 +903,10 @@ setMethod("initcap", column(jc) }) -#' is.nan -#' -#' Return true if the column is NaN, alias for \link{isnan} -#' -#' @param x Column to compute on. -#' -#' @rdname is.nan -#' @name is.nan -#' @family normal_funcs -#' @aliases is.nan,Column-method -#' @export -#' @examples -#' \dontrun{ -#' is.nan(df$c) -#' isnan(df$c) -#' } -#' @note is.nan since 2.0.0 -setMethod("is.nan", - signature(x = "Column"), - function(x) { - isnan(x) - }) - -#' @rdname is.nan -#' @name isnan -#' @aliases isnan,Column-method +#' @details +#' \code{isnan}: Returns true if the column is NaN. +#' @rdname column_nonaggregate_functions +#' @aliases isnan isnan,Column-method #' @note isnan since 2.0.0 setMethod("isnan", signature(x = "Column"), @@ -823,18 +915,29 @@ setMethod("isnan", column(jc) }) -#' kurtosis -#' -#' Aggregate function: returns the kurtosis of the values in a group. +#' @details +#' \code{is.nan}: Alias for \link{isnan}. #' -#' @param x Column to compute on. +#' @rdname column_nonaggregate_functions +#' @aliases is.nan is.nan,Column-method +#' @export +#' @note is.nan since 2.0.0 +setMethod("is.nan", + signature(x = "Column"), + function(x) { + isnan(x) + }) + +#' @details +#' \code{kurtosis}: Returns the kurtosis of the values in a group. #' -#' @rdname kurtosis -#' @name kurtosis -#' @aliases kurtosis,Column-method -#' @family agg_funcs +#' @rdname column_aggregate_functions +#' @aliases kurtosis kurtosis,Column-method #' @export -#' @examples \dontrun{kurtosis(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, mean(df$mpg), sd(df$mpg), skewness(df$mpg), kurtosis(df$mpg)))} #' @note kurtosis since 1.6.0 setMethod("kurtosis", signature(x = "Column"), @@ -858,7 +961,7 @@ setMethod("kurtosis", #' @rdname last #' @name last #' @aliases last,characterOrColumn-method -#' @family agg_funcs +#' @family aggregate functions #' @export #' @examples #' \dontrun{ @@ -878,20 +981,18 @@ setMethod("last", column(jc) }) -#' last_day -#' -#' Given a date column, returns the last day of the month which the given date belongs to. -#' For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the -#' month in July 2015. -#' -#' @param x Column to compute on. +#' @details +#' \code{last_day}: Given a date column, returns the last day of the month which the +#' given date belongs to. For example, input "2015-07-27" returns "2015-07-31" since +#' July 31 is the last day of the month in July 2015. 
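The \code{isnan} and \code{is.nan} entries above are documented without an inline example. A minimal sketch, assuming an active session and an illustrative numeric column containing NaN (none of this is part of the patch):

library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(v = c(1.0, NaN, 3.0)))

# both flag NaN entries and return a boolean column; is.nan simply delegates to isnan
head(select(df, df$v, isnan(df$v), is.nan(df$v)))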
#' -#' @rdname last_day -#' @name last_day -#' @aliases last_day,Column-method -#' @family datetime_funcs +#' @rdname column_datetime_functions +#' @aliases last_day last_day,Column-method #' @export -#' @examples \dontrun{last_day(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, df$time, last_day(df$time), month(df$time)))} #' @note last_day since 1.5.0 setMethod("last_day", signature(x = "Column"), @@ -900,18 +1001,12 @@ setMethod("last_day", column(jc) }) -#' length -#' -#' Computes the length of a given string or binary column. +#' @details +#' \code{length}: Computes the length of a given string or binary column. #' -#' @param x Column to compute on. -#' -#' @rdname length -#' @name length -#' @aliases length,Column-method -#' @family string_funcs +#' @rdname column_string_functions +#' @aliases length length,Column-method #' @export -#' @examples \dontrun{length(df$c)} #' @note length since 1.5.0 setMethod("length", signature(x = "Column"), @@ -920,18 +1015,12 @@ setMethod("length", column(jc) }) -#' log -#' -#' Computes the natural logarithm of the given value. +#' @details +#' \code{log}: Computes the natural logarithm of the given value. #' -#' @param x Column to compute on. -#' -#' @rdname log -#' @name log -#' @aliases log,Column-method -#' @family math_funcs +#' @rdname column_math_functions +#' @aliases log log,Column-method #' @export -#' @examples \dontrun{log(df$c)} #' @note log since 1.5.0 setMethod("log", signature(x = "Column"), @@ -940,18 +1029,12 @@ setMethod("log", column(jc) }) -#' log10 +#' @details +#' \code{log10}: Computes the logarithm of the given value in base 10. #' -#' Computes the logarithm of the given value in base 10. -#' -#' @param x Column to compute on. -#' -#' @rdname log10 -#' @name log10 -#' @family math_funcs -#' @aliases log10,Column-method +#' @rdname column_math_functions +#' @aliases log10 log10,Column-method #' @export -#' @examples \dontrun{log10(df$c)} #' @note log10 since 1.5.0 setMethod("log10", signature(x = "Column"), @@ -960,18 +1043,12 @@ setMethod("log10", column(jc) }) -#' log1p -#' -#' Computes the natural logarithm of the given value plus one. -#' -#' @param x Column to compute on. +#' @details +#' \code{log1p}: Computes the natural logarithm of the given value plus one. #' -#' @rdname log1p -#' @name log1p -#' @family math_funcs -#' @aliases log1p,Column-method +#' @rdname column_math_functions +#' @aliases log1p log1p,Column-method #' @export -#' @examples \dontrun{log1p(df$c)} #' @note log1p since 1.5.0 setMethod("log1p", signature(x = "Column"), @@ -980,18 +1057,12 @@ setMethod("log1p", column(jc) }) -#' log2 -#' -#' Computes the logarithm of the given column in base 2. -#' -#' @param x Column to compute on. +#' @details +#' \code{log2}: Computes the logarithm of the given column in base 2. #' -#' @rdname log2 -#' @name log2 -#' @family math_funcs -#' @aliases log2,Column-method +#' @rdname column_math_functions +#' @aliases log2 log2,Column-method #' @export -#' @examples \dontrun{log2(df$c)} #' @note log2 since 1.5.0 setMethod("log2", signature(x = "Column"), @@ -1000,18 +1071,12 @@ setMethod("log2", column(jc) }) -#' lower -#' -#' Converts a string column to lower case. -#' -#' @param x Column to compute on. +#' @details +#' \code{lower}: Converts a string column to lower case. 
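The \code{log}, \code code{log10}, \code{log1p} and \code{log2} entries above have no inline example on the consolidated math page. A small sketch under the same assumptions as the other sketches (invented data, active session):

library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(v = c(1, 8, 1024)))

# natural log, base-10 log, log(1 + x), and base-2 log
head(select(df, df$v, log(df$v), log10(df$v), log1p(df$v), log2(df$v)))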
#' -#' @rdname lower -#' @name lower -#' @family string_funcs -#' @aliases lower,Column-method +#' @rdname column_string_functions +#' @aliases lower lower,Column-method #' @export -#' @examples \dontrun{lower(df$c)} #' @note lower since 1.4.0 setMethod("lower", signature(x = "Column"), @@ -1020,18 +1085,24 @@ setMethod("lower", column(jc) }) -#' ltrim +#' @details +#' \code{ltrim}: Trims the spaces from left end for the specified string value. #' -#' Trim the spaces from left end for the specified string value. -#' -#' @param x Column to compute on. -#' -#' @rdname ltrim -#' @name ltrim -#' @family string_funcs -#' @aliases ltrim,Column-method +#' @rdname column_string_functions +#' @aliases ltrim ltrim,Column-method #' @export -#' @examples \dontrun{ltrim(df$c)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, SexLpad = lpad(df$Sex, 6, " "), SexRpad = rpad(df$Sex, 7, " ")) +#' head(select(tmp, length(tmp$Sex), length(tmp$SexLpad), length(tmp$SexRpad))) +#' tmp2 <- mutate(tmp, SexLtrim = ltrim(tmp$SexLpad), SexRtrim = rtrim(tmp$SexRpad), +#' SexTrim = trim(tmp$SexLpad)) +#' head(select(tmp2, length(tmp2$Sex), length(tmp2$SexLtrim), +#' length(tmp2$SexRtrim), length(tmp2$SexTrim))) +#' +#' tmp <- mutate(df, SexLpad = lpad(df$Sex, 6, "xx"), SexRpad = rpad(df$Sex, 7, "xx")) +#' head(tmp)} #' @note ltrim since 1.5.0 setMethod("ltrim", signature(x = "Column"), @@ -1040,18 +1111,11 @@ setMethod("ltrim", column(jc) }) -#' max +#' @details +#' \code{max}: Returns the maximum value of the expression in a group. #' -#' Aggregate function: returns the maximum value of the expression in a group. -#' -#' @param x Column to compute on. -#' -#' @rdname max -#' @name max -#' @family agg_funcs -#' @aliases max,Column-method -#' @export -#' @examples \dontrun{max(df$c)} +#' @rdname column_aggregate_functions +#' @aliases max max,Column-method #' @note max since 1.5.0 setMethod("max", signature(x = "Column"), @@ -1060,19 +1124,13 @@ setMethod("max", column(jc) }) -#' md5 -#' -#' Calculates the MD5 digest of a binary column and returns the value +#' @details +#' \code{md5}: Calculates the MD5 digest of a binary column and returns the value #' as a 32 character hex string. #' -#' @param x Column to compute on. -#' -#' @rdname md5 -#' @name md5 -#' @family misc_funcs -#' @aliases md5,Column-method +#' @rdname column_misc_functions +#' @aliases md5 md5,Column-method #' @export -#' @examples \dontrun{md5(df$c)} #' @note md5 since 1.5.0 setMethod("md5", signature(x = "Column"), @@ -1081,19 +1139,24 @@ setMethod("md5", column(jc) }) -#' mean +#' @details +#' \code{mean}: Returns the average of the values in a group. Alias for \code{avg}. #' -#' Aggregate function: returns the average of the values in a group. -#' Alias for avg. +#' @rdname column_aggregate_functions +#' @aliases mean mean,Column-method +#' @export +#' @examples #' -#' @param x Column to compute on. 
+#' \dontrun{ +#' head(select(df, avg(df$mpg), mean(df$mpg), sum(df$mpg), min(df$wt), max(df$qsec))) #' -#' @rdname mean -#' @name mean -#' @family agg_funcs -#' @aliases mean,Column-method -#' @export -#' @examples \dontrun{mean(df$c)} +#' # metrics by num of cylinders +#' tmp <- agg(groupBy(df, "cyl"), avg(df$mpg), avg(df$hp), avg(df$wt), avg(df$qsec)) +#' head(orderBy(tmp, "cyl")) +#' +#' # car with the max mpg +#' mpg_max <- as.numeric(collect(agg(df, max(df$mpg)))) +#' head(where(df, df$mpg == mpg_max))} #' @note mean since 1.5.0 setMethod("mean", signature(x = "Column"), @@ -1102,18 +1165,12 @@ setMethod("mean", column(jc) }) -#' min -#' -#' Aggregate function: returns the minimum value of the expression in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{min}: Returns the minimum value of the expression in a group. #' -#' @rdname min -#' @name min -#' @aliases min,Column-method -#' @family agg_funcs +#' @rdname column_aggregate_functions +#' @aliases min min,Column-method #' @export -#' @examples \dontrun{min(df$c)} #' @note min since 1.5.0 setMethod("min", signature(x = "Column"), @@ -1122,18 +1179,12 @@ setMethod("min", column(jc) }) -#' minute +#' @details +#' \code{minute}: Extracts the minute as an integer from a given date/timestamp/string. #' -#' Extracts the minutes as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. -#' -#' @rdname minute -#' @name minute -#' @aliases minute,Column-method -#' @family datetime_funcs +#' @rdname column_datetime_functions +#' @aliases minute minute,Column-method #' @export -#' @examples \dontrun{minute(df$c)} #' @note minute since 1.5.0 setMethod("minute", signature(x = "Column"), @@ -1142,27 +1193,24 @@ setMethod("minute", column(jc) }) -#' monotonically_increasing_id -#' -#' Return a column that generates monotonically increasing 64-bit integers. -#' -#' The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. -#' The current implementation puts the partition ID in the upper 31 bits, and the record number -#' within each partition in the lower 33 bits. The assumption is that the SparkDataFrame has -#' less than 1 billion partitions, and each partition has less than 8 billion records. -#' -#' As an example, consider a SparkDataFrame with two partitions, each with 3 records. +#' @details +#' \code{monotonically_increasing_id}: Returns a column that generates monotonically increasing +#' 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, +#' but not consecutive. The current implementation puts the partition ID in the upper 31 bits, +#' and the record number within each partition in the lower 33 bits. The assumption is that the +#' SparkDataFrame has less than 1 billion partitions, and each partition has less than 8 billion +#' records. As an example, consider a SparkDataFrame with two partitions, each with 3 records. #' This expression would return the following IDs: #' 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. -#' #' This is equivalent to the MONOTONICALLY_INCREASING_ID function in SQL. +#' The method should be used with no argument. 
#' -#' @rdname monotonically_increasing_id -#' @aliases monotonically_increasing_id,missing-method -#' @name monotonically_increasing_id -#' @family misc_funcs +#' @rdname column_nonaggregate_functions +#' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method #' @export -#' @examples \dontrun{select(df, monotonically_increasing_id())} +#' @examples +#' +#' \dontrun{head(select(df, monotonically_increasing_id()))} setMethod("monotonically_increasing_id", signature("missing"), function() { @@ -1170,18 +1218,12 @@ setMethod("monotonically_increasing_id", column(jc) }) -#' month -#' -#' Extracts the month as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{month}: Extracts the month as an integer from a given date/timestamp/string. #' -#' @rdname month -#' @name month -#' @aliases month,Column-method -#' @family datetime_funcs +#' @rdname column_datetime_functions +#' @aliases month month,Column-method #' @export -#' @examples \dontrun{month(df$c)} #' @note month since 1.5.0 setMethod("month", signature(x = "Column"), @@ -1190,18 +1232,12 @@ setMethod("month", column(jc) }) -#' negate -#' -#' Unary minus, i.e. negate the expression. -#' -#' @param x Column to compute on. +#' @details +#' \code{negate}: Unary minus, i.e. negate the expression. #' -#' @rdname negate -#' @name negate -#' @family normal_funcs -#' @aliases negate,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases negate negate,Column-method #' @export -#' @examples \dontrun{negate(df$c)} #' @note negate since 1.5.0 setMethod("negate", signature(x = "Column"), @@ -1210,18 +1246,12 @@ setMethod("negate", column(jc) }) -#' quarter +#' @details +#' \code{quarter}: Extracts the quarter as an integer from a given date/timestamp/string. #' -#' Extracts the quarter as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. -#' -#' @rdname quarter -#' @name quarter -#' @family datetime_funcs -#' @aliases quarter,Column-method +#' @rdname column_datetime_functions +#' @aliases quarter quarter,Column-method #' @export -#' @examples \dontrun{quarter(df$c)} #' @note quarter since 1.5.0 setMethod("quarter", signature(x = "Column"), @@ -1230,18 +1260,12 @@ setMethod("quarter", column(jc) }) -#' reverse -#' -#' Reverses the string column and returns it as a new string column. +#' @details +#' \code{reverse}: Reverses the string column and returns it as a new string column. #' -#' @param x Column to compute on. -#' -#' @rdname reverse -#' @name reverse -#' @family string_funcs -#' @aliases reverse,Column-method +#' @rdname column_string_functions +#' @aliases reverse reverse,Column-method #' @export -#' @examples \dontrun{reverse(df$c)} #' @note reverse since 1.5.0 setMethod("reverse", signature(x = "Column"), @@ -1250,19 +1274,13 @@ setMethod("reverse", column(jc) }) -#' rint -#' -#' Returns the double value that is closest in value to the argument and +#' @details +#' \code{rint}: Returns the double value that is closest in value to the argument and #' is equal to a mathematical integer. #' -#' @param x Column to compute on. 
-#' -#' @rdname rint -#' @name rint -#' @family math_funcs -#' @aliases rint,Column-method +#' @rdname column_math_functions +#' @aliases rint rint,Column-method #' @export -#' @examples \dontrun{rint(df$c)} #' @note rint since 1.5.0 setMethod("rint", signature(x = "Column"), @@ -1271,18 +1289,13 @@ setMethod("rint", column(jc) }) -#' round +#' @details +#' \code{round}: Returns the value of the column rounded to 0 decimal places +#' using HALF_UP rounding mode. #' -#' Returns the value of the column \code{e} rounded to 0 decimal places using HALF_UP rounding mode. -#' -#' @param x Column to compute on. -#' -#' @rdname round -#' @name round -#' @family math_funcs -#' @aliases round,Column-method +#' @rdname column_math_functions +#' @aliases round round,Column-method #' @export -#' @examples \dontrun{round(df$c)} #' @note round since 1.5.0 setMethod("round", signature(x = "Column"), @@ -1291,24 +1304,18 @@ setMethod("round", column(jc) }) -#' bround -#' -#' Returns the value of the column \code{e} rounded to \code{scale} decimal places using HALF_EVEN rounding -#' mode if \code{scale} >= 0 or at integer part when \code{scale} < 0. +#' @details +#' \code{bround}: Returns the value of the column \code{e} rounded to \code{scale} decimal places +#' using HALF_EVEN rounding mode if \code{scale} >= 0 or at integer part when \code{scale} < 0. #' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. #' bround(2.5, 0) = 2, bround(3.5, 0) = 4. #' -#' @param x Column to compute on. #' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0, #' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left #' of the decimal point when \code{scale} < 0. -#' @param ... further arguments to be passed to or from other methods. -#' @rdname bround -#' @name bround -#' @family math_funcs -#' @aliases bround,Column-method +#' @rdname column_math_functions +#' @aliases bround bround,Column-method #' @export -#' @examples \dontrun{bround(df$c, 0)} #' @note bround since 2.0.0 setMethod("bround", signature(x = "Column"), @@ -1317,19 +1324,12 @@ setMethod("bround", column(jc) }) - -#' rtrim +#' @details +#' \code{rtrim}: Trims the spaces from right end for the specified string value. #' -#' Trim the spaces from right end for the specified string value. -#' -#' @param x Column to compute on. -#' -#' @rdname rtrim -#' @name rtrim -#' @family string_funcs -#' @aliases rtrim,Column-method +#' @rdname column_string_functions +#' @aliases rtrim rtrim,Column-method #' @export -#' @examples \dontrun{rtrim(df$c)} #' @note rtrim since 1.5.0 setMethod("rtrim", signature(x = "Column"), @@ -1338,24 +1338,16 @@ setMethod("rtrim", column(jc) }) -#' sd -#' -#' Aggregate function: alias for \link{stddev_samp} +#' @details +#' \code{sd}: Alias for \code{stddev_samp}. #' -#' @param x Column to compute on. -#' @param na.rm currently not used. 
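The HALF_UP versus HALF_EVEN distinction between \code{round} and \code{bround} described above is easiest to see side by side. A minimal sketch, with values invented for illustration and not taken from the patch:

library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(v = c(2.5, 3.5)))

# round uses HALF_UP: 2.5 -> 3, 3.5 -> 4
# bround uses HALF_EVEN (bankers' rounding): 2.5 -> 2, 3.5 -> 4
head(select(df, df$v, round(df$v), bround(df$v, 0)))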
-#' @rdname sd -#' @name sd -#' @family agg_funcs -#' @aliases sd,Column-method -#' @seealso \link{stddev_pop}, \link{stddev_samp} +#' @rdname column_aggregate_functions +#' @aliases sd sd,Column-method #' @export #' @examples -#'\dontrun{ -#'stddev(df$c) -#'select(df, stddev(df$age)) -#'agg(df, sd(df$age)) -#'} +#' +#' \dontrun{ +#' head(select(df, sd(df$mpg), stddev(df$mpg), stddev_pop(df$wt), stddev_samp(df$qsec)))} #' @note sd since 1.6.0 setMethod("sd", signature(x = "Column"), @@ -1364,18 +1356,12 @@ setMethod("sd", stddev_samp(x) }) -#' second -#' -#' Extracts the seconds as an integer from a given date/timestamp/string. +#' @details +#' \code{second}: Extracts the second as an integer from a given date/timestamp/string. #' -#' @param x Column to compute on. -#' -#' @rdname second -#' @name second -#' @family datetime_funcs -#' @aliases second,Column-method +#' @rdname column_datetime_functions +#' @aliases second second,Column-method #' @export -#' @examples \dontrun{second(df$c)} #' @note second since 1.5.0 setMethod("second", signature(x = "Column"), @@ -1384,19 +1370,13 @@ setMethod("second", column(jc) }) -#' sha1 -#' -#' Calculates the SHA-1 digest of a binary column and returns the value +#' @details +#' \code{sha1}: Calculates the SHA-1 digest of a binary column and returns the value #' as a 40 character hex string. #' -#' @param x Column to compute on. -#' -#' @rdname sha1 -#' @name sha1 -#' @family misc_funcs -#' @aliases sha1,Column-method +#' @rdname column_misc_functions +#' @aliases sha1 sha1,Column-method #' @export -#' @examples \dontrun{sha1(df$c)} #' @note sha1 since 1.5.0 setMethod("sha1", signature(x = "Column"), @@ -1405,18 +1385,12 @@ setMethod("sha1", column(jc) }) -#' signum -#' -#' Computes the signum of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{signum}: Computes the signum of the given value. #' -#' @rdname sign -#' @name signum -#' @aliases signum,Column-method -#' @family math_funcs +#' @rdname column_math_functions +#' @aliases signum signum,Column-method #' @export -#' @examples \dontrun{signum(df$c)} #' @note signum since 1.5.0 setMethod("signum", signature(x = "Column"), @@ -1425,18 +1399,24 @@ setMethod("signum", column(jc) }) -#' sin +#' @details +#' \code{sign}: Alias for \code{signum}. #' -#' Computes the sine of the given value. -#' -#' @param x Column to compute on. +#' @rdname column_math_functions +#' @aliases sign sign,Column-method +#' @export +#' @note sign since 1.5.0 +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' @details +#' \code{sin}: Computes the sine of the given value. Units in radians. #' -#' @rdname sin -#' @name sin -#' @family math_funcs -#' @aliases sin,Column-method +#' @rdname column_math_functions +#' @aliases sin sin,Column-method #' @export -#' @examples \dontrun{sin(df$c)} #' @note sin since 1.5.0 setMethod("sin", signature(x = "Column"), @@ -1445,18 +1425,12 @@ setMethod("sin", column(jc) }) -#' sinh +#' @details +#' \code{sinh}: Computes the hyperbolic sine of the given value. #' -#' Computes the hyperbolic sine of the given value. -#' -#' @param x Column to compute on. 
-#' -#' @rdname sinh -#' @name sinh -#' @family math_funcs -#' @aliases sinh,Column-method +#' @rdname column_math_functions +#' @aliases sinh sinh,Column-method #' @export -#' @examples \dontrun{sinh(df$c)} #' @note sinh since 1.5.0 setMethod("sinh", signature(x = "Column"), @@ -1465,18 +1439,12 @@ setMethod("sinh", column(jc) }) -#' skewness -#' -#' Aggregate function: returns the skewness of the values in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{skewness}: Returns the skewness of the values in a group. #' -#' @rdname skewness -#' @name skewness -#' @family agg_funcs -#' @aliases skewness,Column-method +#' @rdname column_aggregate_functions +#' @aliases skewness skewness,Column-method #' @export -#' @examples \dontrun{skewness(df$c)} #' @note skewness since 1.6.0 setMethod("skewness", signature(x = "Column"), @@ -1485,18 +1453,12 @@ setMethod("skewness", column(jc) }) -#' soundex -#' -#' Return the soundex code for the specified expression. -#' -#' @param x Column to compute on. +#' @details +#' \code{soundex}: Returns the soundex code for the specified expression. #' -#' @rdname soundex -#' @name soundex -#' @family string_funcs -#' @aliases soundex,Column-method +#' @rdname column_string_functions +#' @aliases soundex soundex,Column-method #' @export -#' @examples \dontrun{soundex(df$c)} #' @note soundex since 1.5.0 setMethod("soundex", signature(x = "Column"), @@ -1505,20 +1467,18 @@ setMethod("soundex", column(jc) }) -#' Return the partition ID as a column -#' -#' Return the partition ID as a SparkDataFrame column. +#' @details +#' \code{spark_partition_id}: Returns the partition ID as a SparkDataFrame column. #' Note that this is nondeterministic because it depends on data partitioning and #' task scheduling. +#' This is equivalent to the \code{SPARK_PARTITION_ID} function in SQL. #' -#' This is equivalent to the SPARK_PARTITION_ID function in SQL. -#' -#' @rdname spark_partition_id -#' @name spark_partition_id -#' @aliases spark_partition_id,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases spark_partition_id spark_partition_id,missing-method #' @export #' @examples -#' \dontrun{select(df, spark_partition_id())} +#' +#' \dontrun{head(select(df, spark_partition_id()))} #' @note spark_partition_id since 2.0.0 setMethod("spark_partition_id", signature("missing"), @@ -1527,9 +1487,11 @@ setMethod("spark_partition_id", column(jc) }) -#' @rdname sd -#' @aliases stddev,Column-method -#' @name stddev +#' @details +#' \code{stddev}: Alias for \code{std_dev}. +#' +#' @rdname column_aggregate_functions +#' @aliases stddev stddev,Column-method #' @note stddev since 1.6.0 setMethod("stddev", signature(x = "Column"), @@ -1538,19 +1500,12 @@ setMethod("stddev", column(jc) }) -#' stddev_pop +#' @details +#' \code{stddev_pop}: Returns the population standard deviation of the expression in a group. #' -#' Aggregate function: returns the population standard deviation of the expression in a group. -#' -#' @param x Column to compute on. 
-#' -#' @rdname stddev_pop -#' @name stddev_pop -#' @family agg_funcs -#' @aliases stddev_pop,Column-method -#' @seealso \link{sd}, \link{stddev_samp} +#' @rdname column_aggregate_functions +#' @aliases stddev_pop stddev_pop,Column-method #' @export -#' @examples \dontrun{stddev_pop(df$c)} #' @note stddev_pop since 1.6.0 setMethod("stddev_pop", signature(x = "Column"), @@ -1559,19 +1514,12 @@ setMethod("stddev_pop", column(jc) }) -#' stddev_samp -#' -#' Aggregate function: returns the unbiased sample standard deviation of the expression in a group. +#' @details +#' \code{stddev_samp}: Returns the unbiased sample standard deviation of the expression in a group. #' -#' @param x Column to compute on. -#' -#' @rdname stddev_samp -#' @name stddev_samp -#' @family agg_funcs -#' @aliases stddev_samp,Column-method -#' @seealso \link{stddev_pop}, \link{sd} +#' @rdname column_aggregate_functions +#' @aliases stddev_samp stddev_samp,Column-method #' @export -#' @examples \dontrun{stddev_samp(df$c)} #' @note stddev_samp since 1.6.0 setMethod("stddev_samp", signature(x = "Column"), @@ -1580,23 +1528,19 @@ setMethod("stddev_samp", column(jc) }) -#' struct -#' -#' Creates a new struct column that composes multiple input columns. -#' -#' @param x a column to compute on. -#' @param ... optional column(s) to be included. +#' @details +#' \code{struct}: Creates a new struct column that composes multiple input columns. #' -#' @rdname struct -#' @name struct -#' @family normal_funcs -#' @aliases struct,characterOrColumn-method +#' @rdname column_nonaggregate_functions +#' @aliases struct struct,characterOrColumn-method #' @export #' @examples +#' #' \dontrun{ -#' struct(df$c, df$d) -#' struct("col1", "col2") -#' } +#' tmp <- mutate(df, v1 = struct(df$mpg, df$cyl), v2 = struct("hp", "wt", "vs"), +#' v3 = create_array(df$mpg, df$cyl, df$hp), +#' v4 = create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))) +#' head(tmp)} #' @note struct since 1.6.0 setMethod("struct", signature(x = "characterOrColumn"), @@ -1610,18 +1554,12 @@ setMethod("struct", column(jc) }) -#' sqrt -#' -#' Computes the square root of the specified float value. +#' @details +#' \code{sqrt}: Computes the square root of the specified float value. #' -#' @param x Column to compute on. -#' -#' @rdname sqrt -#' @name sqrt -#' @family math_funcs -#' @aliases sqrt,Column-method +#' @rdname column_math_functions +#' @aliases sqrt sqrt,Column-method #' @export -#' @examples \dontrun{sqrt(df$c)} #' @note sqrt since 1.5.0 setMethod("sqrt", signature(x = "Column"), @@ -1630,18 +1568,12 @@ setMethod("sqrt", column(jc) }) -#' sum +#' @details +#' \code{sum}: Returns the sum of all values in the expression. #' -#' Aggregate function: returns the sum of all values in the expression. -#' -#' @param x Column to compute on. -#' -#' @rdname sum -#' @name sum -#' @family agg_funcs -#' @aliases sum,Column-method +#' @rdname column_aggregate_functions +#' @aliases sum sum,Column-method #' @export -#' @examples \dontrun{sum(df$c)} #' @note sum since 1.5.0 setMethod("sum", signature(x = "Column"), @@ -1650,18 +1582,17 @@ setMethod("sum", column(jc) }) -#' sumDistinct -#' -#' Aggregate function: returns the sum of distinct values in the expression. -#' -#' @param x Column to compute on. +#' @details +#' \code{sumDistinct}: Returns the sum of distinct values in the expression. 
#' -#' @rdname sumDistinct -#' @name sumDistinct -#' @family agg_funcs -#' @aliases sumDistinct,Column-method +#' @rdname column_aggregate_functions +#' @aliases sumDistinct sumDistinct,Column-method #' @export -#' @examples \dontrun{sumDistinct(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, sumDistinct(df$gear))) +#' head(distinct(select(df, "gear")))} #' @note sumDistinct since 1.4.0 setMethod("sumDistinct", signature(x = "Column"), @@ -1670,18 +1601,12 @@ setMethod("sumDistinct", column(jc) }) -#' tan -#' -#' Computes the tangent of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{tan}: Computes the tangent of the given value. Units in radians. #' -#' @rdname tan -#' @name tan -#' @family math_funcs -#' @aliases tan,Column-method +#' @rdname column_math_functions +#' @aliases tan tan,Column-method #' @export -#' @examples \dontrun{tan(df$c)} #' @note tan since 1.5.0 setMethod("tan", signature(x = "Column"), @@ -1690,18 +1615,12 @@ setMethod("tan", column(jc) }) -#' tanh +#' @details +#' \code{tanh}: Computes the hyperbolic tangent of the given value. #' -#' Computes the hyperbolic tangent of the given value. -#' -#' @param x Column to compute on. -#' -#' @rdname tanh -#' @name tanh -#' @family math_funcs -#' @aliases tanh,Column-method +#' @rdname column_math_functions +#' @aliases tanh tanh,Column-method #' @export -#' @examples \dontrun{tanh(df$c)} #' @note tanh since 1.5.0 setMethod("tanh", signature(x = "Column"), @@ -1710,18 +1629,13 @@ setMethod("tanh", column(jc) }) -#' toDegrees -#' -#' Converts an angle measured in radians to an approximately equivalent angle measured in degrees. +#' @details +#' \code{toDegrees}: Converts an angle measured in radians to an approximately equivalent angle +#' measured in degrees. #' -#' @param x Column to compute on. -#' -#' @rdname toDegrees -#' @name toDegrees -#' @family math_funcs -#' @aliases toDegrees,Column-method +#' @rdname column_math_functions +#' @aliases toDegrees toDegrees,Column-method #' @export -#' @examples \dontrun{toDegrees(df$c)} #' @note toDegrees since 1.4.0 setMethod("toDegrees", signature(x = "Column"), @@ -1730,18 +1644,13 @@ setMethod("toDegrees", column(jc) }) -#' toRadians +#' @details +#' \code{toRadians}: Converts an angle measured in degrees to an approximately equivalent angle +#' measured in radians. #' -#' Converts an angle measured in degrees to an approximately equivalent angle measured in radians. -#' -#' @param x Column to compute on. -#' -#' @rdname toRadians -#' @name toRadians -#' @family math_funcs -#' @aliases toRadians,Column-method +#' @rdname column_math_functions +#' @aliases toRadians toRadians,Column-method #' @export -#' @examples \dontrun{toRadians(df$c)} #' @note toRadians since 1.4.0 setMethod("toRadians", signature(x = "Column"), @@ -1750,28 +1659,28 @@ setMethod("toRadians", column(jc) }) -#' to_date -#' -#' Converts the column into a DateType. You may optionally specify a format -#' according to the rules in: +#' @details +#' \code{to_date}: Converts the column into a DateType. You may optionally specify +#' a format according to the rules in: #' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. -#' The default format is 'yyyy-MM-dd'. +#' By default, it follows casting rules to a DateType if the format is omitted +#' (equivalent to \code{cast(df$x, "date")}). #' -#' @param x Column to parse. 
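The \code{toDegrees} and \code{toRadians} entries above carry no inline example. A brief sketch with an invented angle column, assuming an active session:

library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(rad = c(0, pi / 2, pi)))

# radians -> degrees and back; a small floating-point error is expected
head(select(df, df$rad, toDegrees(df$rad), toRadians(toDegrees(df$rad))))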
-#' @param format string to use to parse x Column to DateType. (optional) -#' -#' @rdname to_date -#' @name to_date -#' @family datetime_funcs -#' @aliases to_date,Column,missing-method +#' @rdname column_datetime_functions +#' @aliases to_date to_date,Column,missing-method #' @export #' @examples +#' #' \dontrun{ -#' to_date(df$c) -#' to_date(df$c, 'yyyy-MM-dd') -#' } +#' tmp <- createDataFrame(data.frame(time_string = dts)) +#' tmp2 <- mutate(tmp, date1 = to_date(tmp$time_string), +#' date2 = to_date(tmp$time_string, "yyyy-MM-dd"), +#' date3 = date_format(tmp$time_string, "MM/dd/yyy"), +#' time1 = to_timestamp(tmp$time_string), +#' time2 = to_timestamp(tmp$time_string, "yyyy-MM-dd")) +#' head(tmp2)} #' @note to_date(Column) since 1.5.0 setMethod("to_date", signature(x = "Column", format = "missing"), @@ -1780,9 +1689,7 @@ setMethod("to_date", column(jc) }) -#' @rdname to_date -#' @name to_date -#' @family datetime_funcs +#' @rdname column_datetime_functions #' @aliases to_date,Column,character-method #' @export #' @note to_date(Column, character) since 2.2.0 @@ -1793,30 +1700,32 @@ setMethod("to_date", column(jc) }) -#' to_json -#' -#' Converts a column containing a \code{structType} or array of \code{structType} into a Column -#' of JSON string. Resolving the Column can fail if an unsupported type is encountered. -#' -#' @param x Column containing the struct or array of the structs -#' @param ... additional named properties to control how it is converted, accepts the same options -#' as the JSON data source. +#' @details +#' \code{to_json}: Converts a column containing a \code{structType}, array of \code{structType}, +#' a \code{mapType} or array of \code{mapType} into a Column of JSON string. +#' Resolving the Column can fail if an unsupported type is encountered. #' -#' @family normal_funcs -#' @rdname to_json -#' @name to_json -#' @aliases to_json,Column-method +#' @rdname column_collection_functions +#' @aliases to_json to_json,Column-method #' @export #' @examples +#' #' \dontrun{ #' # Converts a struct into a JSON object -#' df <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") -#' select(df, to_json(df$d, dateFormat = 'dd/MM/yyyy')) +#' df2 <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' select(df2, to_json(df2$d, dateFormat = 'dd/MM/yyyy')) #' #' # Converts an array of structs into a JSON array -#' df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") -#' select(df, to_json(df$people)) -#'} +#' df2 <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' +#' # Converts a map into a JSON object +#' df2 <- sql("SELECT map('name', 'Bob')) as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' +#' # Converts an array of maps into a JSON array +#' df2 <- sql("SELECT array(map('name', 'Bob'), map('name', 'Alice')) as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people))} #' @note to_json since 2.2.0 setMethod("to_json", signature(x = "Column"), function(x, ...) { @@ -1825,28 +1734,18 @@ setMethod("to_json", signature(x = "Column"), column(jc) }) -#' to_timestamp -#' -#' Converts the column into a TimestampType. You may optionally specify a format -#' according to the rules in: +#' @details +#' \code{to_timestamp}: Converts the column into a TimestampType. 
You may optionally specify +#' a format according to the rules in: #' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. -#' The default format is 'yyyy-MM-dd HH:mm:ss'. +#' By default, it follows casting rules to a TimestampType if the format is omitted +#' (equivalent to \code{cast(df$x, "timestamp")}). #' -#' @param x Column to parse. -#' @param format string to use to parse x Column to DateType. (optional) -#' -#' @rdname to_timestamp -#' @name to_timestamp -#' @family datetime_funcs -#' @aliases to_timestamp,Column,missing-method +#' @rdname column_datetime_functions +#' @aliases to_timestamp to_timestamp,Column,missing-method #' @export -#' @examples -#' \dontrun{ -#' to_timestamp(df$c) -#' to_timestamp(df$c, 'yyyy-MM-dd') -#' } #' @note to_timestamp(Column) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "missing"), @@ -1855,9 +1754,7 @@ setMethod("to_timestamp", column(jc) }) -#' @rdname to_timestamp -#' @name to_timestamp -#' @family datetime_funcs +#' @rdname column_datetime_functions #' @aliases to_timestamp,Column,character-method #' @export #' @note to_timestamp(Column, character) since 2.2.0 @@ -1868,18 +1765,12 @@ setMethod("to_timestamp", column(jc) }) -#' trim +#' @details +#' \code{trim}: Trims the spaces from both ends for the specified string column. #' -#' Trim the spaces from both ends for the specified string column. -#' -#' @param x Column to compute on. -#' -#' @rdname trim -#' @name trim -#' @family string_funcs -#' @aliases trim,Column-method +#' @rdname column_string_functions +#' @aliases trim trim,Column-method #' @export -#' @examples \dontrun{trim(df$c)} #' @note trim since 1.5.0 setMethod("trim", signature(x = "Column"), @@ -1888,19 +1779,13 @@ setMethod("trim", column(jc) }) -#' unbase64 -#' -#' Decodes a BASE64 encoded string column and returns it as a binary column. +#' @details +#' \code{unbase64}: Decodes a BASE64 encoded string column and returns it as a binary column. #' This is the reverse of base64. #' -#' @param x Column to compute on. -#' -#' @rdname unbase64 -#' @name unbase64 -#' @family string_funcs -#' @aliases unbase64,Column-method +#' @rdname column_string_functions +#' @aliases unbase64 unbase64,Column-method #' @export -#' @examples \dontrun{unbase64(df$c)} #' @note unbase64 since 1.5.0 setMethod("unbase64", signature(x = "Column"), @@ -1909,19 +1794,13 @@ setMethod("unbase64", column(jc) }) -#' unhex -#' -#' Inverse of hex. Interprets each pair of characters as a hexadecimal number +#' @details +#' \code{unhex}: Inverse of hex. Interprets each pair of characters as a hexadecimal number #' and converts to the byte representation of number. #' -#' @param x Column to compute on. -#' -#' @rdname unhex -#' @name unhex -#' @family math_funcs -#' @aliases unhex,Column-method +#' @rdname column_math_functions +#' @aliases unhex unhex,Column-method #' @export -#' @examples \dontrun{unhex(df$c)} #' @note unhex since 1.5.0 setMethod("unhex", signature(x = "Column"), @@ -1930,18 +1809,12 @@ setMethod("unhex", column(jc) }) -#' upper -#' -#' Converts a string column to upper case. -#' -#' @param x Column to compute on. +#' @details +#' \code{upper}: Converts a string column to upper case. 
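The \code{unbase64} and \code{unhex} entries above describe them as the inverses of \code{base64} and \code{hex}. A minimal round-trip sketch, not part of the patch; the string column is illustrative, and \code{cast} is used only to render the decoded binary results as strings:

library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(s = c("Spark", "R"), stringsAsFactors = FALSE))

# encode to hex and base64, then decode back; the decoded columns are binary
tmp <- mutate(df, h = hex(df$s), b64 = base64(df$s))
head(select(tmp, tmp$s, cast(unhex(tmp$h), "string"), cast(unbase64(tmp$b64), "string")))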
#' -#' @rdname upper -#' @name upper -#' @family string_funcs -#' @aliases upper,Column-method +#' @rdname column_string_functions +#' @aliases upper upper,Column-method #' @export -#' @examples \dontrun{upper(df$c)} #' @note upper since 1.4.0 setMethod("upper", signature(x = "Column"), @@ -1950,24 +1823,16 @@ setMethod("upper", column(jc) }) -#' var -#' -#' Aggregate function: alias for \link{var_samp}. +#' @details +#' \code{var}: Alias for \code{var_samp}. #' -#' @param x a Column to compute on. -#' @param y,na.rm,use currently not used. -#' @rdname var -#' @name var -#' @family agg_funcs -#' @aliases var,Column-method -#' @seealso \link{var_pop}, \link{var_samp} +#' @rdname column_aggregate_functions +#' @aliases var var,Column-method #' @export #' @examples +#' #'\dontrun{ -#'variance(df$c) -#'select(df, var_pop(df$age)) -#'agg(df, var(df$age)) -#'} +#'head(agg(df, var(df$mpg), variance(df$mpg), var_pop(df$mpg), var_samp(df$mpg)))} #' @note var since 1.6.0 setMethod("var", signature(x = "Column"), @@ -1976,9 +1841,9 @@ setMethod("var", var_samp(x) }) -#' @rdname var -#' @aliases variance,Column-method -#' @name variance +#' @rdname column_aggregate_functions +#' @aliases variance variance,Column-method +#' @export #' @note variance since 1.6.0 setMethod("variance", signature(x = "Column"), @@ -1987,19 +1852,12 @@ setMethod("variance", column(jc) }) -#' var_pop -#' -#' Aggregate function: returns the population variance of the values in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{var_pop}: Returns the population variance of the values in a group. #' -#' @rdname var_pop -#' @name var_pop -#' @family agg_funcs -#' @aliases var_pop,Column-method -#' @seealso \link{var}, \link{var_samp} +#' @rdname column_aggregate_functions +#' @aliases var_pop var_pop,Column-method #' @export -#' @examples \dontrun{var_pop(df$c)} #' @note var_pop since 1.5.0 setMethod("var_pop", signature(x = "Column"), @@ -2008,19 +1866,12 @@ setMethod("var_pop", column(jc) }) -#' var_samp -#' -#' Aggregate function: returns the unbiased variance of the values in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{var_samp}: Returns the unbiased variance of the values in a group. #' -#' @rdname var_samp -#' @name var_samp -#' @aliases var_samp,Column-method -#' @family agg_funcs -#' @seealso \link{var_pop}, \link{var} +#' @rdname column_aggregate_functions +#' @aliases var_samp var_samp,Column-method #' @export -#' @examples \dontrun{var_samp(df$c)} #' @note var_samp since 1.6.0 setMethod("var_samp", signature(x = "Column"), @@ -2029,18 +1880,12 @@ setMethod("var_samp", column(jc) }) -#' weekofyear +#' @details +#' \code{weekofyear}: Extracts the week number as an integer from a given date/timestamp/string. #' -#' Extracts the week number as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. -#' -#' @rdname weekofyear -#' @name weekofyear -#' @aliases weekofyear,Column-method -#' @family datetime_funcs +#' @rdname column_datetime_functions +#' @aliases weekofyear weekofyear,Column-method #' @export -#' @examples \dontrun{weekofyear(df$c)} #' @note weekofyear since 1.5.0 setMethod("weekofyear", signature(x = "Column"), @@ -2049,18 +1894,12 @@ setMethod("weekofyear", column(jc) }) -#' year -#' -#' Extracts the year as an integer from a given date/timestamp/string. +#' @details +#' \code{year}: Extracts the year as an integer from a given date/timestamp/string. #' -#' @param x Column to compute on. 
-#' -#' @rdname year -#' @name year -#' @family datetime_funcs -#' @aliases year,Column-method +#' @rdname column_datetime_functions +#' @aliases year year,Column-method #' @export -#' @examples \dontrun{year(df$c)} #' @note year since 1.5.0 setMethod("year", signature(x = "Column"), @@ -2069,20 +1908,13 @@ setMethod("year", column(jc) }) -#' atan2 -#' -#' Returns the angle theta from the conversion of rectangular coordinates (x, y) to -#' polar coordinates (r, theta). -# -#' @param x Column to compute on. -#' @param y Column to compute on. +#' @details +#' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates +#' (x, y) to polar coordinates (r, theta). Units in radians. #' -#' @rdname atan2 -#' @name atan2 -#' @family math_funcs -#' @aliases atan2,Column-method +#' @rdname column_math_functions +#' @aliases atan2 atan2,Column-method #' @export -#' @examples \dontrun{atan2(df$c, x)} #' @note atan2 since 1.5.0 setMethod("atan2", signature(y = "Column"), function(y, x) { @@ -2093,19 +1925,20 @@ setMethod("atan2", signature(y = "Column"), column(jc) }) -#' datediff -#' -#' Returns the number of days from \code{start} to \code{end}. +#' @details +#' \code{datediff}: Returns the number of days from \code{y} to \code{x}. #' -#' @param x start Column to use. -#' @param y end Column to use. -#' -#' @rdname datediff -#' @name datediff -#' @aliases datediff,Column-method -#' @family datetime_funcs +#' @rdname column_datetime_diff_functions +#' @aliases datediff datediff,Column-method #' @export -#' @examples \dontrun{datediff(df$c, x)} +#' @examples +#' +#' \dontrun{ +#' tmp <- createDataFrame(data.frame(time_string1 = as.POSIXct(dts), +#' time_string2 = as.POSIXct(dts[order(runif(length(dts)))]))) +#' tmp2 <- mutate(tmp, datediff = datediff(tmp$time_string1, tmp$time_string2), +#' monthdiff = months_between(tmp$time_string1, tmp$time_string2)) +#' head(tmp2)} #' @note datediff since 1.5.0 setMethod("datediff", signature(y = "Column"), function(y, x) { @@ -2116,19 +1949,12 @@ setMethod("datediff", signature(y = "Column"), column(jc) }) -#' hypot -#' -#' Computes "sqrt(a^2 + b^2)" without intermediate overflow or underflow. -# -#' @param x Column to compute on. -#' @param y Column to compute on. +#' @details +#' \code{hypot}: Computes "sqrt(a^2 + b^2)" without intermediate overflow or underflow. #' -#' @rdname hypot -#' @name hypot -#' @family math_funcs -#' @aliases hypot,Column-method +#' @rdname column_math_functions +#' @aliases hypot hypot,Column-method #' @export -#' @examples \dontrun{hypot(df$c, x)} #' @note hypot since 1.4.0 setMethod("hypot", signature(y = "Column"), function(y, x) { @@ -2139,19 +1965,19 @@ setMethod("hypot", signature(y = "Column"), column(jc) }) -#' levenshtein -#' -#' Computes the Levenshtein distance of the two given string columns. +#' @details +#' \code{levenshtein}: Computes the Levenshtein distance of the two given string columns. #' -#' @param x Column to compute on. -#' @param y Column to compute on. 
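The \code{atan2} and \code{hypot} entries above no longer have inline examples. A small sketch converting invented x/y coordinates to polar form (radius and angle in radians):

library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(x = c(3, 0), y = c(4, 1)))

# r = hypot(x, y) = sqrt(x^2 + y^2); theta = atan2(y, x), in radians
head(select(df, hypot(df$x, df$y), atan2(df$y, df$x)))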
-#' -#' @rdname levenshtein -#' @name levenshtein -#' @family string_funcs -#' @aliases levenshtein,Column-method +#' @rdname column_string_functions +#' @aliases levenshtein levenshtein,Column-method #' @export -#' @examples \dontrun{levenshtein(df$c, x)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, d1 = levenshtein(df$Class, df$Sex), +#' d2 = levenshtein(df$Age, df$Sex), +#' d3 = levenshtein(df$Age, df$Age)) +#' head(tmp)} #' @note levenshtein since 1.5.0 setMethod("levenshtein", signature(y = "Column"), function(y, x) { @@ -2162,19 +1988,12 @@ setMethod("levenshtein", signature(y = "Column"), column(jc) }) -#' months_between +#' @details +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. #' -#' Returns number of months between dates \code{date1} and \code{date2}. -#' -#' @param x start Column to use. -#' @param y end Column to use. -#' -#' @rdname months_between -#' @name months_between -#' @family datetime_funcs -#' @aliases months_between,Column-method +#' @rdname column_datetime_diff_functions +#' @aliases months_between months_between,Column-method #' @export -#' @examples \dontrun{months_between(df$c, x)} #' @note months_between since 1.5.0 setMethod("months_between", signature(y = "Column"), function(y, x) { @@ -2185,20 +2004,13 @@ setMethod("months_between", signature(y = "Column"), column(jc) }) -#' nanvl -#' -#' Returns col1 if it is not NaN, or col2 if col1 is NaN. -#' Both inputs should be floating point columns (DoubleType or FloatType). +#' @details +#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column (\code{x}) if +#' the first column is NaN. Both inputs should be floating point columns (DoubleType or FloatType). #' -#' @param x first Column. -#' @param y second Column. -#' -#' @rdname nanvl -#' @name nanvl -#' @family normal_funcs -#' @aliases nanvl,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases nanvl nanvl,Column-method #' @export -#' @examples \dontrun{nanvl(df$c, x)} #' @note nanvl since 1.5.0 setMethod("nanvl", signature(y = "Column"), function(y, x) { @@ -2209,20 +2021,13 @@ setMethod("nanvl", signature(y = "Column"), column(jc) }) -#' pmod -#' -#' Returns the positive value of dividend mod divisor. +#' @details +#' \code{pmod}: Returns the positive value of dividend mod divisor. +#' Column \code{x} is divisor column, and column \code{y} is the dividend column. #' -#' @param x divisor Column. -#' @param y dividend Column. -#' -#' @rdname pmod -#' @name pmod -#' @docType methods -#' @family math_funcs -#' @aliases pmod,Column-method +#' @rdname column_math_functions +#' @aliases pmod pmod,Column-method #' @export -#' @examples \dontrun{pmod(df$c, x)} #' @note pmod since 1.5.0 setMethod("pmod", signature(y = "Column"), function(y, x) { @@ -2233,17 +2038,11 @@ setMethod("pmod", signature(y = "Column"), column(jc) }) - -#' @rdname approxCountDistinct -#' @name approxCountDistinct -#' -#' @param x Column to compute on. -#' @param rsd maximum estimation error allowed (default = 0.05) -#' @param ... further arguments to be passed to or from other methods. +#' @param rsd maximum estimation error allowed (default = 0.05). 
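The \code{nanvl} and \code{pmod} entries above are documented without inline examples. A minimal sketch with columns invented for illustration:

library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(a = c(-7, 7, NaN), b = c(3, 3, 3)))

# pmod always returns a non-negative remainder, e.g. pmod(-7, 3) is 2
# nanvl keeps a where it is not NaN and falls back to b otherwise
head(select(df, df$a, df$b, pmod(df$a, df$b), nanvl(df$a, df$b)))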
#' +#' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method #' @export -#' @examples \dontrun{approxCountDistinct(df$c, 0.02)} #' @note approxCountDistinct(Column, numeric) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -2252,18 +2051,12 @@ setMethod("approxCountDistinct", column(jc) }) -#' Count Distinct Values +#' @details +#' \code{countDistinct}: Returns the number of distinct items in a group. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family agg_funcs -#' @rdname countDistinct -#' @name countDistinct -#' @aliases countDistinct,Column-method -#' @return the number of distinct items in a group. +#' @rdname column_aggregate_functions +#' @aliases countDistinct countDistinct,Column-method #' @export -#' @examples \dontrun{countDistinct(df$c)} #' @note countDistinct since 1.4.0 setMethod("countDistinct", signature(x = "Column"), @@ -2277,20 +2070,22 @@ setMethod("countDistinct", column(jc) }) - -#' concat -#' -#' Concatenates multiple input string columns together into a single string column. +#' @details +#' \code{concat}: Concatenates multiple input string columns together into a single string column. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family string_funcs -#' @rdname concat -#' @name concat -#' @aliases concat,Column-method +#' @rdname column_string_functions +#' @aliases concat concat,Column-method #' @export -#' @examples \dontrun{concat(df$strings, df$strings2)} +#' @examples +#' +#' \dontrun{ +#' # concatenate strings +#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex), +#' s2 = concat(df$Class, df$Sex, df$Age), +#' s3 = concat(df$Class, df$Sex, df$Age, df$Class), +#' s4 = concat_ws("_", df$Class, df$Sex), +#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) +#' head(tmp)} #' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), @@ -2303,20 +2098,13 @@ setMethod("concat", column(jc) }) -#' greatest -#' -#' Returns the greatest value of the list of column names, skipping null values. +#' @details +#' \code{greatest}: Returns the greatest value of the list of column names, skipping null values. #' This function takes at least 2 parameters. It will return null if all parameters are null. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family normal_funcs -#' @rdname greatest -#' @name greatest -#' @aliases greatest,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases greatest greatest,Column-method #' @export -#' @examples \dontrun{greatest(df$c, df$d)} #' @note greatest since 1.5.0 setMethod("greatest", signature(x = "Column"), @@ -2330,20 +2118,13 @@ setMethod("greatest", column(jc) }) -#' least -#' -#' Returns the least value of the list of column names, skipping null values. +#' @details +#' \code{least}: Returns the least value of the list of column names, skipping null values. #' This function takes at least 2 parameters. It will return null if all parameters are null. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family normal_funcs -#' @rdname least -#' @aliases least,Column-method -#' @name least +#' @rdname column_nonaggregate_functions +#' @aliases least least,Column-method #' @export -#' @examples \dontrun{least(df$c, df$d)} #' @note least since 1.5.0 setMethod("least", signature(x = "Column"), @@ -2357,40 +2138,12 @@ setMethod("least", column(jc) }) -#' @rdname ceil +#' @details +#' \code{n_distinct}: Returns the number of distinct items in a group. 
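The \code{greatest} and \code{least} entries above take two or more columns and skip null values. A minimal sketch with invented columns containing some NAs (which become nulls):

library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(a = c(1, 10, NA), b = c(5, 2, 7), z = c(3, NA, NA)))

# nulls are skipped per row; a row whose inputs are all null would yield null
head(select(df, greatest(df$a, df$b, df$z), least(df$a, df$b, df$z)))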
#' -#' @name ceiling -#' @aliases ceiling,Column-method +#' @rdname column_aggregate_functions +#' @aliases n_distinct n_distinct,Column-method #' @export -#' @examples \dontrun{ceiling(df$c)} -#' @note ceiling since 1.5.0 -setMethod("ceiling", - signature(x = "Column"), - function(x) { - ceil(x) - }) - -#' @rdname sign -#' -#' @name sign -#' @aliases sign,Column-method -#' @export -#' @examples \dontrun{sign(df$c)} -#' @note sign since 1.5.0 -setMethod("sign", signature(x = "Column"), - function(x) { - signum(x) - }) - -#' n_distinct -#' -#' Aggregate function: returns the number of distinct items in a group. -#' -#' @rdname countDistinct -#' @name n_distinct -#' @aliases n_distinct,Column-method -#' @export -#' @examples \dontrun{n_distinct(df$c)} #' @note n_distinct since 1.4.0 setMethod("n_distinct", signature(x = "Column"), function(x, ...) { @@ -2407,27 +2160,19 @@ setMethod("n", signature(x = "Column"), function(x) { count(x) }) - -#' date_format -#' -#' Converts a date/timestamp/string to a value of string in the format specified by the date -#' format given by the second argument. -#' -#' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All + +#' @details +#' \code{date_format}: Converts a date/timestamp/string to a value of string in the format +#' specified by the date format given by the second argument. A pattern could be for instance +#' \code{dd.MM.yyyy} and could return a string like '18.03.1993'. All #' pattern letters of \code{java.text.SimpleDateFormat} can be used. -#' #' Note: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. #' -#' @param y Column to compute on. -#' @param x date format specification. +#' @rdname column_datetime_diff_functions #' -#' @family datetime_funcs -#' @rdname date_format -#' @name date_format -#' @aliases date_format,Column,character-method +#' @aliases date_format date_format,Column,character-method #' @export -#' @examples \dontrun{date_format(df$t, 'MM/dd/yyy')} #' @note date_format since 1.5.0 setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { @@ -2435,31 +2180,37 @@ setMethod("date_format", signature(y = "Column", x = "character"), column(jc) }) -#' from_json -#' -#' Parses a column containing a JSON string into a Column of \code{structType} with the specified -#' \code{schema} or array of \code{structType} if \code{as.json.array} is set to \code{TRUE}. -#' If the string is unparseable, the Column will contains the value NA. +#' @details +#' \code{from_json}: Parses a column containing a JSON string into a Column of \code{structType} +#' with the specified \code{schema} or array of \code{structType} if \code{as.json.array} is set +#' to \code{TRUE}. If the string is unparseable, the Column will contain the value NA. #' -#' @param x Column containing the JSON string. +#' @rdname column_collection_functions #' @param schema a structType object to use as the schema to use when parsing the JSON string. +#' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. -#' @param ... additional named properties to control how the json is parsed, accepts the same -#' options as the JSON data source. 
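The \code{date_format} details above mention SimpleDateFormat patterns such as "dd.MM.yyyy" producing strings like '18.03.1993'. A minimal sketch, with the date values invented for illustration:

library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(d = as.Date(c("1993-03-18", "2015-07-27"))))

# format a DateType column as day.month.year strings, e.g. "18.03.1993"
head(select(df, df$d, date_format(df$d, "dd.MM.yyyy")))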
-#' -#' @family normal_funcs -#' @rdname from_json -#' @name from_json -#' @aliases from_json,Column,structType-method +#' @aliases from_json from_json,Column,characterOrstructType-method #' @export #' @examples +#' #' \dontrun{ -#' schema <- structType(structField("name", "string"), -#' select(df, from_json(df$value, schema, dateFormat = "dd/MM/yyyy")) -#'} +#' df2 <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' df2 <- mutate(df2, d2 = to_json(df2$d, dateFormat = 'dd/MM/yyyy')) +#' schema <- structType(structField("date", "string")) +#' head(select(df2, from_json(df2$d2, schema, dateFormat = 'dd/MM/yyyy'))) + +#' df2 <- sql("SELECT named_struct('name', 'Bob') as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' schema <- structType(structField("name", "string")) +#' head(select(df2, from_json(df2$people_json, schema))) +#' head(select(df2, from_json(df2$people_json, "name STRING")))} #' @note from_json since 2.2.0 -setMethod("from_json", signature(x = "Column", schema = "structType"), +setMethod("from_json", signature(x = "Column", schema = "characterOrstructType"), function(x, schema, as.json.array = FALSE, ...) { + if (is.character(schema)) { + schema <- structType(schema) + } + if (as.json.array) { jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", "createArrayType", @@ -2474,20 +2225,21 @@ setMethod("from_json", signature(x = "Column", schema = "structType"), column(jc) }) -#' from_utc_timestamp -#' -#' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp -#' that corresponds to the same time of day in the given timezone. +#' @details +#' \code{from_utc_timestamp}: Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a +#' time in UTC, and renders that time as a timestamp in the given time zone. For example, 'GMT+1' +#' would yield '2017-07-14 03:40:00.0'. #' -#' @param y Column to compute on. -#' @param x time zone to use. +#' @rdname column_datetime_diff_functions #' -#' @family datetime_funcs -#' @rdname from_utc_timestamp -#' @name from_utc_timestamp -#' @aliases from_utc_timestamp,Column,character-method +#' @aliases from_utc_timestamp from_utc_timestamp,Column,character-method #' @export -#' @examples \dontrun{from_utc_timestamp(df$t, 'PST')} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, "PST"), +#' to_utc = to_utc_timestamp(df$time, "PST")) +#' head(tmp)} #' @note from_utc_timestamp since 1.5.0 setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2495,22 +2247,21 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), column(jc) }) -#' instr -#' -#' Locate the position of the first occurrence of substr column in the given string. -#' Returns null if either of the arguments are null. -#' -#' Note: The position is not zero based, but 1 based index. Returns 0 if substr -#' could not be found in str. +#' @details +#' \code{instr}: Locates the position of the first occurrence of a substring (\code{x}) +#' in the given string column (\code{y}). Returns null if either of the arguments are null. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the substring +#' could not be found in the string column. 
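The new `characterOrstructType` signature for `from_json` means a DDL-formatted string can stand in for an explicit `structType` (Spark 2.3+). A hedged sketch; `df2` and its columns are invented for illustration:

```r
# assumes: library(SparkR); sparkR.session()
df2 <- sql("SELECT named_struct('name', 'Bob', 'age', 35) AS people")
df2 <- mutate(df2, people_json = to_json(df2$people))

# Equivalent schemas: an explicit structType vs. a DDL string
schema <- structType(structField("name", "string"), structField("age", "integer"))
head(select(df2, from_json(df2$people_json, schema)))
head(select(df2, from_json(df2$people_json, "name STRING, age INT")))
```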
#' -#' @param y column to check -#' @param x substring to check -#' @family string_funcs -#' @aliases instr,Column,character-method -#' @rdname instr -#' @name instr +#' @rdname column_string_functions +#' @aliases instr instr,Column,character-method #' @export -#' @examples \dontrun{instr(df$c, 'b')} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = instr(df$Sex, "m"), s2 = instr(df$Sex, "M"), +#' s3 = locate("m", df$Sex), s4 = locate("m", df$Sex, pos = 4)) +#' head(tmp)} #' @note instr since 1.5.0 setMethod("instr", signature(y = "Column", x = "character"), function(y, x) { @@ -2518,30 +2269,16 @@ setMethod("instr", signature(y = "Column", x = "character"), column(jc) }) -#' next_day -#' -#' Given a date column, returns the first date which is later than the value of the date column -#' that is on the specified day of the week. -#' -#' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first -#' Sunday after 2015-07-27. -#' -#' Day of the week parameter is case insensitive, and accepts first three or two characters: -#' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". +#' @details +#' \code{next_day}: Given a date column, returns the first date which is later than the value of +#' the date column that is on the specified day of the week. For example, +#' \code{next_day("2015-07-27", "Sunday")} returns 2015-08-02 because that is the first Sunday +#' after 2015-07-27. Day of the week parameter is case insensitive, and accepts first three or +#' two characters: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' -#' @param y Column to compute on. -#' @param x Day of the week string. -#' -#' @family datetime_funcs -#' @rdname next_day -#' @name next_day -#' @aliases next_day,Column,character-method +#' @rdname column_datetime_diff_functions +#' @aliases next_day next_day,Column,character-method #' @export -#' @examples -#'\dontrun{ -#'next_day(df$d, 'Sun') -#'next_day(df$d, 'Sunday') -#'} #' @note next_day since 1.5.0 setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { @@ -2549,20 +2286,14 @@ setMethod("next_day", signature(y = "Column", x = "character"), column(jc) }) -#' to_utc_timestamp +#' @details +#' \code{to_utc_timestamp}: Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a +#' time in the given time zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' +#' would yield '2017-07-14 01:40:00.0'. #' -#' Given a timestamp, which corresponds to a certain time of day in the given timezone, returns -#' another timestamp that corresponds to the same time of day in UTC. -#' -#' @param y Column to compute on -#' @param x timezone to use -#' -#' @family datetime_funcs -#' @rdname to_utc_timestamp -#' @name to_utc_timestamp -#' @aliases to_utc_timestamp,Column,character-method +#' @rdname column_datetime_diff_functions +#' @aliases to_utc_timestamp to_utc_timestamp,Column,character-method #' @export -#' @examples \dontrun{to_utc_timestamp(df$t, 'PST')} #' @note to_utc_timestamp since 1.5.0 setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2570,19 +2301,20 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), column(jc) }) -#' add_months -#' -#' Returns the date that is numMonths after startDate. -#' -#' @param y Column to compute on -#' @param x Number of months to add +#' @details +#' \code{add_months}: Returns the date that is numMonths (\code{x}) after startDate (\code{y}). 
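A small sketch combining `instr`, `next_day` and the two UTC conversions documented above; the `Sex`/`ts` columns are hypothetical:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(data.frame(
  Sex = c("Male", "Female"),
  ts  = as.POSIXct("2015-07-27 02:40:00", tz = "UTC"),
  stringsAsFactors = FALSE))

head(mutate(df,
            pos_m    = instr(df$Sex, "m"),                 # 1-based; 0 when absent
            next_sun = next_day(df$ts, "Sunday"),          # first Sunday after the date
            as_utc   = to_utc_timestamp(df$ts, "PST"),     # treat ts as PST, render in UTC
            from_utc = from_utc_timestamp(df$ts, "PST")))  # treat ts as UTC, render in PST
```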
#' -#' @name add_months -#' @family datetime_funcs -#' @rdname add_months -#' @aliases add_months,Column,numeric-method +#' @rdname column_datetime_diff_functions +#' @aliases add_months add_months,Column,numeric-method #' @export -#' @examples \dontrun{add_months(df$d, 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, t1 = add_months(df$time, 1), +#' t2 = date_add(df$time, 2), +#' t3 = date_sub(df$time, 3), +#' t4 = next_day(df$time, "Sun")) +#' head(tmp)} #' @note add_months since 1.5.0 setMethod("add_months", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2590,19 +2322,12 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), column(jc) }) -#' date_add -#' -#' Returns the date that is \code{x} days after +#' @details +#' \code{date_add}: Returns the date that is \code{x} days after. #' -#' @param y Column to compute on -#' @param x Number of days to add -#' -#' @family datetime_funcs -#' @rdname date_add -#' @name date_add -#' @aliases date_add,Column,numeric-method +#' @rdname column_datetime_diff_functions +#' @aliases date_add date_add,Column,numeric-method #' @export -#' @examples \dontrun{date_add(df$d, 1)} #' @note date_add since 1.5.0 setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2610,19 +2335,13 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), column(jc) }) -#' date_sub -#' -#' Returns the date that is \code{x} days before +#' @details +#' \code{date_sub}: Returns the date that is \code{x} days before. #' -#' @param y Column to compute on -#' @param x Number of days to substract +#' @rdname column_datetime_diff_functions #' -#' @family datetime_funcs -#' @rdname date_sub -#' @name date_sub -#' @aliases date_sub,Column,numeric-method +#' @aliases date_sub date_sub,Column,numeric-method #' @export -#' @examples \dontrun{date_sub(df$d, 1)} #' @note date_sub since 1.5.0 setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2630,22 +2349,22 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), column(jc) }) -#' format_number -#' -#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places -#' with HALF_EVEN round mode, and returns the result as a string column. -#' -#' If x is 0, the result has no decimal point or fractional part. -#' If x < 0, the result will be null. +#' @details +#' \code{format_number}: Formats numeric column \code{y} to a format like '#,###,###.##', +#' rounded to \code{x} decimal places with HALF_EVEN round mode, and returns the result +#' as a string column. +#' If \code{x} is 0, the result has no decimal point or fractional part. +#' If \code{x} < 0, the result will be null. 
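The three date-shift helpers grouped onto `column_datetime_diff_functions` above, shown on a single hypothetical date column:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(data.frame(d = as.Date("2015-07-27")))

head(mutate(df,
            plus_month = add_months(df$d, 1),   # 2015-08-27
            plus_days  = date_add(df$d, 2),     # 2015-07-29
            minus_days = date_sub(df$d, 3)))    # 2015-07-24
```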
#' -#' @param y column to format -#' @param x number of decimal place to format to -#' @family string_funcs -#' @rdname format_number -#' @name format_number -#' @aliases format_number,Column,numeric-method +#' @rdname column_string_functions +#' @aliases format_number format_number,Column,numeric-method #' @export -#' @examples \dontrun{format_number(df$n, 4)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, v1 = df$Freq/3) +#' head(select(tmp, format_number(tmp$v1, 0), format_number(tmp$v1, 2), +#' format_string("%4.2f %s", tmp$v1, tmp$Sex)), 10)} #' @note format_number since 1.5.0 setMethod("format_number", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2655,19 +2374,14 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), column(jc) }) -#' sha2 -#' -#' Calculates the SHA-2 family of hash functions of a binary column and -#' returns the value as a hex string. +#' @details +#' \code{sha2}: Calculates the SHA-2 family of hash functions of a binary column and +#' returns the value as a hex string. The second argument \code{x} specifies the number +#' of bits, and is one of 224, 256, 384, or 512. #' -#' @param y column to compute SHA-2 on. -#' @param x one of 224, 256, 384, or 512. -#' @family misc_funcs -#' @rdname sha2 -#' @name sha2 -#' @aliases sha2,Column,numeric-method +#' @rdname column_misc_functions +#' @aliases sha2 sha2,Column,numeric-method #' @export -#' @examples \dontrun{sha2(df$c, 256)} #' @note sha2 since 1.5.0 setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2675,20 +2389,13 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftLeft +#' @details +#' \code{shiftLeft}: Shifts the given value numBits left. If the given value is a long value, +#' this function will return a long value else it will return an integer value. #' -#' Shift the given value numBits left. If the given value is a long value, this function -#' will return a long value else it will return an integer value. -#' -#' @param y column to compute on. -#' @param x number of bits to shift. -#' -#' @family math_funcs -#' @rdname shiftLeft -#' @name shiftLeft -#' @aliases shiftLeft,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftLeft shiftLeft,Column,numeric-method #' @export -#' @examples \dontrun{shiftLeft(df$c, 1)} #' @note shiftLeft since 1.5.0 setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2698,20 +2405,13 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftRight -#' -#' (Signed) shift the given value numBits right. If the given value is a long value, it will return -#' a long value else it will return an integer value. -#' -#' @param y column to compute on. -#' @param x number of bits to shift. +#' @details +#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long value, +#' it will return a long value else it will return an integer value. 
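A short sketch of `format_number`, `sha2` and `shiftLeft`; the column names and values are made up:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(data.frame(n = c(1234.5678, 2.5),
                                 i = c(21L, 3L),
                                 s = c("abc", "def"),
                                 stringsAsFactors = FALSE))

head(mutate(df,
            pretty  = format_number(df$n, 2),   # "1,234.57", "2.50"
            digest  = sha2(df$s, 256),          # hex-encoded SHA-256
            shifted = shiftLeft(df$i, 1)))      # 42, 6
```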
#' -#' @family math_funcs -#' @rdname shiftRight -#' @name shiftRight -#' @aliases shiftRight,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftRight shiftRight,Column,numeric-method #' @export -#' @examples \dontrun{shiftRight(df$c, 1)} #' @note shiftRight since 1.5.0 setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2721,20 +2421,13 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftRightUnsigned -#' -#' Unsigned shift the given value numBits right. If the given value is a long value, -#' it will return a long value else it will return an integer value. -#' -#' @param y column to compute on. -#' @param x number of bits to shift. +#' @details +#' \code{shiftRightUnsigned}: (Unigned) shifts the given value numBits right. If the given value is +#' a long value, it will return a long value else it will return an integer value. #' -#' @family math_funcs -#' @rdname shiftRightUnsigned -#' @name shiftRightUnsigned -#' @aliases shiftRightUnsigned,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method #' @export -#' @examples \dontrun{shiftRightUnsigned(df$c, 1)} #' @note shiftRightUnsigned since 1.5.0 setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2744,21 +2437,14 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), column(jc) }) -#' concat_ws +#' @details +#' \code{concat_ws}: Concatenates multiple input string columns together into a single +#' string column, using the given separator. #' -#' Concatenates multiple input string columns together into a single string column, -#' using the given separator. -#' -#' @param x column to concatenate. #' @param sep separator to use. -#' @param ... other columns to concatenate. -#' -#' @family string_funcs -#' @rdname concat_ws -#' @name concat_ws -#' @aliases concat_ws,character,Column-method +#' @rdname column_string_functions +#' @aliases concat_ws concat_ws,character,Column-method #' @export -#' @examples \dontrun{concat_ws('-', df$s, df$d)} #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2767,20 +2453,14 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), column(jc) }) -#' conv +#' @details +#' \code{conv}: Converts a number in a string column from one base to another. #' -#' Convert a number in a string column from one base to another. -#' -#' @param x column to convert. #' @param fromBase base to convert from. #' @param toBase base to convert to. -#' -#' @family math_funcs -#' @rdname conv -#' @aliases conv,Column,numeric,numeric-method -#' @name conv +#' @rdname column_math_functions +#' @aliases conv conv,Column,numeric,numeric-method #' @export -#' @examples \dontrun{conv(df$n, 2, 16)} #' @note conv since 1.5.0 setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { @@ -2792,18 +2472,13 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri column(jc) }) -#' expr +#' @details +#' \code{expr}: Parses the expression string into the column that it represents, similar to +#' \code{SparkDataFrame.selectExpr} #' -#' Parses the expression string into the column that it represents, similar to -#' SparkDataFrame.selectExpr -#' -#' @param x an expression character object to be parsed. 
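A sketch of the signed and unsigned right shifts alongside `concat_ws` and `conv`; the frame below is invented for illustration:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(data.frame(v = c(8L, -8L),
                                 hex = c("ff", "1a"),
                                 stringsAsFactors = FALSE))

head(mutate(df,
            signed   = shiftRight(df$v, 1),            # -8 >> 1 == -4
            unsigned = shiftRightUnsigned(df$v, 1),    # large positive value for -8
            joined   = concat_ws("-", df$hex, df$hex),
            decimal  = conv(df$hex, 16, 10)))          # "255", "26"
```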
-#' @family normal_funcs -#' @rdname expr -#' @aliases expr,character-method -#' @name expr +#' @rdname column_nonaggregate_functions +#' @aliases expr expr,character-method #' @export -#' @examples \dontrun{expr('length(name)')} #' @note expr since 1.5.0 setMethod("expr", signature(x = "character"), function(x) { @@ -2811,19 +2486,14 @@ setMethod("expr", signature(x = "character"), column(jc) }) -#' format_string -#' -#' Formats the arguments in printf-style and returns the result as a string column. +#' @details +#' \code{format_string}: Formats the arguments in printf-style and returns the result +#' as a string column. #' #' @param format a character object of format strings. -#' @param x a Column. -#' @param ... additional Column(s). -#' @family string_funcs -#' @rdname format_string -#' @name format_string -#' @aliases format_string,character,Column-method -#' @export -#' @examples \dontrun{format_string('%d %s', df$a, df$b)} +#' @rdname column_string_functions +#' @aliases format_string format_string,character,Column-method +#' @export #' @note format_string since 1.5.0 setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { @@ -2834,27 +2504,24 @@ setMethod("format_string", signature(format = "character", x = "Column"), column(jc) }) -#' from_unixtime +#' @details +#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a +#' string representing the timestamp of that moment in the current system time zone in the JVM in the +#' given format. See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' Customizing Formats} for available options. #' -#' Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string -#' representing the timestamp of that moment in the current system time zone in the given -#' format. +#' @rdname column_datetime_functions #' -#' @param x a Column of unix timestamp. -#' @param format the target format. See -#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ -#' Customizing Formats} for available options. -#' @param ... further arguments to be passed to or from other methods. -#' @family datetime_funcs -#' @rdname from_unixtime -#' @name from_unixtime -#' @aliases from_unixtime,Column-method +#' @aliases from_unixtime from_unixtime,Column-method #' @export #' @examples -#'\dontrun{ -#'from_unixtime(df$t) -#'from_unixtime(df$t, 'yyyy/MM/dd HH') -#'} +#' +#' \dontrun{ +#' tmp <- mutate(df, to_unix = unix_timestamp(df$time), +#' to_unix2 = unix_timestamp(df$time, 'yyyy-MM-dd HH'), +#' from_unix = from_unixtime(unix_timestamp(df$time)), +#' from_unix2 = from_unixtime(unix_timestamp(df$time), 'yyyy-MM-dd HH:mm')) +#' head(tmp)} #' @note from_unixtime since 1.5.0 setMethod("from_unixtime", signature(x = "Column"), function(x, format = "yyyy-MM-dd HH:mm:ss") { @@ -2864,14 +2531,13 @@ setMethod("from_unixtime", signature(x = "Column"), column(jc) }) -#' window -#' -#' Bucketize rows into one or more time windows given a timestamp specifying column. Window -#' starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window +#' @details +#' \code{window}: Bucketizes rows into one or more time windows given a timestamp specifying column. +#' Window starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window #' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. 
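`expr`, `format_string` and `from_unixtime` in one pass; a hedged sketch with invented columns (note that `length()` here is SparkR's Column method, not base R's):

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(data.frame(name = c("Ann", "Bob"),
                                 t    = c(1500000000, 1500003600),
                                 stringsAsFactors = FALSE))

head(select(df,
            expr("length(name)"),                              # parsed SQL expression
            format_string("%s:%d", df$name, length(df$name)),  # printf-style
            from_unixtime(df$t, "yyyy-MM-dd HH:mm")))
```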
Windows in -#' the order of months are not supported. +#' the order of months are not supported. It returns an output column of struct called 'window' +#' by default with the nested columns 'start' and 'end' #' -#' @param x a time Column. Must be of TimestampType. #' @param windowDuration a string specifying the width of the window, e.g. '1 second', #' '1 day 12 hours', '2 minutes'. Valid interval strings are 'week', #' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. Note that @@ -2887,27 +2553,22 @@ setMethod("from_unixtime", signature(x = "Column"), #' window intervals. For example, in order to have hourly tumbling windows #' that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide #' \code{startTime} as \code{"15 minutes"}. -#' @param ... further arguments to be passed to or from other methods. -#' @return An output column of struct called 'window' by default with the nested columns 'start' -#' and 'end'. -#' @family datetime_funcs -#' @rdname window -#' @name window -#' @aliases window,Column-method +#' @rdname column_datetime_functions +#' @aliases window window,Column-method #' @export #' @examples -#'\dontrun{ -#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10, -#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ... -#' window(df$time, "1 minute", "15 seconds", "10 seconds") #' -#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15, -#' # 09:01:15-09:02:15... -#' window(df$time, "1 minute", startTime = "15 seconds") +#' \dontrun{ +#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10, +#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ... +#' window(df$time, "1 minute", "15 seconds", "10 seconds") +#' +#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15, +#' # 09:01:15-09:02:15... +#' window(df$time, "1 minute", startTime = "15 seconds") #' -#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... -#' window(df$time, "30 seconds", "10 seconds") -#'} +#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... +#' window(df$time, "30 seconds", "10 seconds")} #' @note window since 2.0.0 setMethod("window", signature(x = "Column"), function(x, windowDuration, slideDuration = NULL, startTime = NULL) { @@ -2935,23 +2596,17 @@ setMethod("window", signature(x = "Column"), column(jc) }) -#' locate -#' -#' Locate the position of the first occurrence of substr. -#' +#' @details +#' \code{locate}: Locates the position of the first occurrence of substr. #' Note: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. #' @param str a Column where matches are sought for each entry. #' @param pos start position of search. -#' @param ... further arguments to be passed to or from other methods. 
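Time windows are usually paired with `groupBy`/`agg`; the struct output column ('window' with nested 'start' and 'end') then carries the bucket bounds. A sketch on a made-up event frame:

```r
# assumes: library(SparkR); sparkR.session()
events <- createDataFrame(data.frame(
  time  = as.POSIXct("2016-01-01 09:00:00", tz = "UTC") + seq(0, 290, by = 10),
  value = rnorm(30)))

# Ten-second tumbling windows, averaged per window
byWindow <- agg(groupBy(events, window(events$time, "10 seconds")),
                avg_value = mean(events$value))
head(byWindow)   # first column is the struct 'window' with 'start' and 'end'
```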
-#' @family string_funcs -#' @rdname locate -#' @aliases locate,character,Column-method -#' @name locate +#' @rdname column_string_functions +#' @aliases locate locate,character,Column-method #' @export -#' @examples \dontrun{locate('b', df$c, 1)} #' @note locate since 1.5.0 setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 1) { @@ -2961,19 +2616,14 @@ setMethod("locate", signature(substr = "character", str = "Column"), column(jc) }) -#' lpad -#' -#' Left-pad the string column with +#' @details +#' \code{lpad}: Left-padded with pad to a length of len. #' -#' @param x the string Column to be left-padded. #' @param len maximum length of each output result. #' @param pad a character string to be padded with. -#' @family string_funcs -#' @rdname lpad -#' @aliases lpad,Column,numeric,character-method -#' @name lpad +#' @rdname column_string_functions +#' @aliases lpad lpad,Column,numeric,character-method #' @export -#' @examples \dontrun{lpad(df$c, 6, '#')} #' @note lpad since 1.5.0 setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2983,18 +2633,19 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), column(jc) }) -#' rand -#' -#' Generate a random column with independent and identically distributed (i.i.d.) samples +#' @details +#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) samples #' from U[0.0, 1.0]. #' +#' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. -#' @family normal_funcs -#' @rdname rand -#' @name rand -#' @aliases rand,missing-method +#' @aliases rand rand,missing-method #' @export -#' @examples \dontrun{rand()} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, r1 = rand(), r2 = rand(10), r3 = randn(), r4 = randn(10)) +#' head(tmp)} #' @note rand since 1.5.0 setMethod("rand", signature(seed = "missing"), function(seed) { @@ -3002,8 +2653,7 @@ setMethod("rand", signature(seed = "missing"), column(jc) }) -#' @rdname rand -#' @name rand +#' @rdname column_nonaggregate_functions #' @aliases rand,numeric-method #' @export #' @note rand(numeric) since 1.5.0 @@ -3013,18 +2663,13 @@ setMethod("rand", signature(seed = "numeric"), column(jc) }) -#' randn -#' -#' Generate a column with independent and identically distributed (i.i.d.) samples from +#' @details +#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples from #' the standard normal distribution. #' -#' @param seed a random seed. Can be missing. -#' @family normal_funcs -#' @rdname randn -#' @name randn -#' @aliases randn,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases randn randn,missing-method #' @export -#' @examples \dontrun{randn()} #' @note randn since 1.5.0 setMethod("randn", signature(seed = "missing"), function(seed) { @@ -3032,8 +2677,7 @@ setMethod("randn", signature(seed = "missing"), column(jc) }) -#' @rdname randn -#' @name randn +#' @rdname column_nonaggregate_functions #' @aliases randn,numeric-method #' @export #' @note randn(numeric) since 1.5.0 @@ -3043,20 +2687,27 @@ setMethod("randn", signature(seed = "numeric"), column(jc) }) -#' regexp_extract +#' @details +#' \code{regexp_extract}: Extracts a specific \code{idx} group identified by a Java regex, +#' from the specified string column. If the regex did not match, or the specified group did +#' not match, an empty string is returned. 
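`locate`, `lpad` and the two random generators, on a hypothetical string column:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(data.frame(Sex = c("Male", "Female"),
                                 stringsAsFactors = FALSE))

head(mutate(df,
            pos    = locate("a", df$Sex),   # 1-based; 0 if absent
            padded = lpad(df$Sex, 8, "#"),
            u      = rand(42),              # i.i.d. U[0, 1) with a fixed seed
            z      = randn(42)))            # i.i.d. standard normal
```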
#' -#' Extract a specific \code{idx} group identified by a Java regex, from the specified string column. -#' If the regex did not match, or the specified group did not match, an empty string is returned. -#' -#' @param x a string Column. #' @param pattern a regular expression. #' @param idx a group index. -#' @family string_funcs -#' @rdname regexp_extract -#' @name regexp_extract -#' @aliases regexp_extract,Column,character,numeric-method +#' @rdname column_string_functions +#' @aliases regexp_extract regexp_extract,Column,character,numeric-method #' @export -#' @examples \dontrun{regexp_extract(df$c, '(\d+)-(\d+)', 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = regexp_extract(df$Class, "(\\d+)\\w+", 1), +#' s2 = regexp_extract(df$Sex, "^(\\w)\\w+", 1), +#' s3 = regexp_replace(df$Class, "\\D+", ""), +#' s4 = substring_index(df$Sex, "a", 1), +#' s5 = substring_index(df$Sex, "a", -1), +#' s6 = translate(df$Sex, "ale", ""), +#' s7 = translate(df$Sex, "a", "-")) +#' head(tmp)} #' @note regexp_extract since 1.5.0 setMethod("regexp_extract", signature(x = "Column", pattern = "character", idx = "numeric"), @@ -3067,19 +2718,14 @@ setMethod("regexp_extract", column(jc) }) -#' regexp_replace +#' @details +#' \code{regexp_replace}: Replaces all substrings of the specified string value that +#' match regexp with rep. #' -#' Replace all substrings of the specified string value that match regexp with rep. -#' -#' @param x a string Column. -#' @param pattern a regular expression. #' @param replacement a character string that a matched \code{pattern} is replaced with. -#' @family string_funcs -#' @rdname regexp_replace -#' @name regexp_replace -#' @aliases regexp_replace,Column,character,character-method +#' @rdname column_string_functions +#' @aliases regexp_replace regexp_replace,Column,character,character-method #' @export -#' @examples \dontrun{regexp_replace(df$c, '(\\d+)', '--')} #' @note regexp_replace since 1.5.0 setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), @@ -3090,19 +2736,12 @@ setMethod("regexp_replace", column(jc) }) -#' rpad +#' @details +#' \code{rpad}: Right-padded with pad to a length of len. #' -#' Right-padded with pad to a length of len. -#' -#' @param x the string Column to be right-padded. -#' @param len maximum length of each output result. -#' @param pad a character string to be padded with. -#' @family string_funcs -#' @rdname rpad -#' @name rpad -#' @aliases rpad,Column,numeric,character-method +#' @rdname column_string_functions +#' @aliases rpad rpad,Column,numeric,character-method #' @export -#' @examples \dontrun{rpad(df$c, 6, '#')} #' @note rpad since 1.5.0 setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -3112,28 +2751,20 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), column(jc) }) -#' substring_index -#' -#' Returns the substring from string str before count occurrences of the delimiter delim. -#' If count is positive, everything the left of the final delimiter (counting from left) is -#' returned. If count is negative, every to the right of the final delimiter (counting from the -#' right) is returned. substring_index performs a case-sensitive match when searching for delim. +#' @details +#' \code{substring_index}: Returns the substring from string str before count occurrences of +#' the delimiter delim. If count is positive, everything the left of the final delimiter +#' (counting from left) is returned. 
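The regex helpers and `rpad`, sketched on an invented column (Java regex syntax, so backslashes are doubled in R strings):

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(data.frame(code = c("1st-class", "2nd-class"),
                                 stringsAsFactors = FALSE))

head(mutate(df,
            digits = regexp_extract(df$code, "(\\d+)\\w*", 1),  # "1", "2"
            masked = regexp_replace(df$code, "\\d", "#"),       # "#st-class", ...
            padded = rpad(df$code, 12, ".")))
```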
If count is negative, every to the right of the final +#' delimiter (counting from the right) is returned. substring_index performs a case-sensitive +#' match when searching for delim. #' -#' @param x a Column. #' @param delim a delimiter string. #' @param count number of occurrences of \code{delim} before the substring is returned. #' A positive number means counting from the left, while negative means #' counting from the right. -#' @family string_funcs -#' @rdname substring_index -#' @aliases substring_index,Column,character,numeric-method -#' @name substring_index +#' @rdname column_string_functions +#' @aliases substring_index substring_index,Column,character,numeric-method #' @export -#' @examples -#'\dontrun{ -#'substring_index(df$c, '.', 2) -#'substring_index(df$c, '.', -1) -#'} #' @note substring_index since 1.5.0 setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), @@ -3144,24 +2775,19 @@ setMethod("substring_index", column(jc) }) -#' translate -#' -#' Translate any character in the src by a character in replaceString. +#' @details +#' \code{translate}: Translates any character in the src by a character in replaceString. #' The characters in replaceString is corresponding to the characters in matchingString. #' The translate will happen when any character in the string matching with the character #' in the matchingString. #' -#' @param x a string Column. #' @param matchingString a source string where each character will be translated. #' @param replaceString a target string where each \code{matchingString} character will #' be replaced by the character in \code{replaceString} #' at the same location, if any. -#' @family string_funcs -#' @rdname translate -#' @name translate -#' @aliases translate,Column,character,character-method +#' @rdname column_string_functions +#' @aliases translate translate,Column,character,character-method #' @export -#' @examples \dontrun{translate(df$c, 'rnlt', '123')} #' @note translate since 1.5.0 setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), @@ -3171,21 +2797,12 @@ setMethod("translate", column(jc) }) -#' unix_timestamp -#' -#' Gets current Unix timestamp in seconds. +#' @details +#' \code{unix_timestamp}: Gets current Unix timestamp in seconds. #' -#' @family datetime_funcs -#' @rdname unix_timestamp -#' @name unix_timestamp -#' @aliases unix_timestamp,missing,missing-method +#' @rdname column_datetime_functions +#' @aliases unix_timestamp unix_timestamp,missing,missing-method #' @export -#' @examples -#'\dontrun{ -#'unix_timestamp() -#'unix_timestamp(df$t) -#'unix_timestamp(df$t, 'yyyy-MM-dd HH') -#'} #' @note unix_timestamp since 1.5.0 setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { @@ -3193,8 +2810,7 @@ setMethod("unix_timestamp", signature(x = "missing", format = "missing"), column(jc) }) -#' @rdname unix_timestamp -#' @name unix_timestamp +#' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,missing-method #' @export #' @note unix_timestamp(Column) since 1.5.0 @@ -3204,12 +2820,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), column(jc) }) -#' @param x a Column of date, in string, date or timestamp type. -#' @param format the target format. See -#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ -#' Customizing Formats} for available options. 
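`substring_index`, `translate` and `unix_timestamp` with an explicit format, again on invented data:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(data.frame(host = "spark.apache.org",
                                 t    = "2015-07-27 10:00:00",
                                 stringsAsFactors = FALSE))

head(mutate(df,
            domain      = substring_index(df$host, ".", -2),   # "apache.org"
            underscored = translate(df$host, ".", "_"),
            seconds     = unix_timestamp(df$t, "yyyy-MM-dd HH:mm:ss")))
```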
-#' @rdname unix_timestamp -#' @name unix_timestamp +#' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,character-method #' @export #' @note unix_timestamp(Column, character) since 1.5.0 @@ -3218,20 +2829,26 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) column(jc) }) -#' when -#' -#' Evaluates a list of conditions and returns one of multiple possible result expressions. + +#' @details +#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result expressions. #' For unmatched expressions null is returned. #' +#' @rdname column_nonaggregate_functions #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. -#' @family normal_funcs -#' @rdname when -#' @name when -#' @aliases when,Column-method -#' @seealso \link{ifelse} +#' @aliases when when,Column-method #' @export -#' @examples \dontrun{when(df$age == 2, df$age + 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, mpg_na = otherwise(when(df$mpg > 20, df$mpg), lit(NaN)), +#' mpg2 = ifelse(df$mpg > 20 & df$am > 0, 0, 1), +#' mpg3 = ifelse(df$mpg > 20, df$mpg, 20.0)) +#' head(tmp) +#' tmp <- mutate(tmp, ind_na1 = is.nan(tmp$mpg_na), ind_na2 = isnan(tmp$mpg_na)) +#' head(select(tmp, coalesce(tmp$mpg_na, tmp$mpg))) +#' head(select(tmp, nanvl(tmp$mpg_na, tmp$hp)))} #' @note when since 1.5.0 setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { @@ -3241,24 +2858,16 @@ setMethod("when", signature(condition = "Column", value = "ANY"), column(jc) }) -#' ifelse -#' -#' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. +#' @details +#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. #' Otherwise \code{no} is returned for unmatched conditions. #' +#' @rdname column_nonaggregate_functions #' @param test a Column expression that describes the condition. #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. -#' @family normal_funcs -#' @rdname ifelse -#' @name ifelse -#' @aliases ifelse,Column-method -#' @seealso \link{when} -#' @export -#' @examples \dontrun{ -#' ifelse(df$a > 1 & df$b > 2, 0, 1) -#' ifelse(df$a > 1, df$a, 1) -#' } +#' @aliases ifelse ifelse,Column-method +#' @export #' @note ifelse since 1.5.0 setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), @@ -3275,26 +2884,16 @@ setMethod("ifelse", ###################### Window functions###################### -#' cume_dist -#' -#' Window function: returns the cumulative distribution of values within a window partition, -#' i.e. the fraction of rows that are below the current row. -#' -#' N = total number of rows in the partition -#' cume_dist(x) = number of values before (and including) x / N -#' +#' @details +#' \code{cume_dist}: Returns the cumulative distribution of values within a window partition, +#' i.e. the fraction of rows that are below the current row: +#' (number of values before and including x) / (total number of rows in the partition). #' This is equivalent to the \code{CUME_DIST} function in SQL. +#' The method should be used with no argument. 
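`when`/`otherwise` and the column form of `ifelse`, sketched on `mtcars` (the same data set used by the surrounding examples):

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(mtcars)

tmp <- mutate(df,
              band  = otherwise(when(df$mpg > 25, "high"), "low"),
              heavy = ifelse(df$wt > 3.5, 1, 0))
head(select(tmp, tmp$mpg, tmp$band, tmp$wt, tmp$heavy))
```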
#' -#' @rdname cume_dist -#' @name cume_dist -#' @family window_funcs -#' @aliases cume_dist,missing-method +#' @rdname column_window_functions +#' @aliases cume_dist cume_dist,missing-method #' @export -#' @examples \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(cume_dist(), ws), df$hp, df$am) -#' } #' @note cume_dist since 1.6.0 setMethod("cume_dist", signature("missing"), @@ -3303,27 +2902,19 @@ setMethod("cume_dist", column(jc) }) -#' dense_rank -#' -#' Window function: returns the rank of rows within a window partition, without any gaps. +#' @details +#' \code{dense_rank}: Returns the rank of rows within a window partition, without any gaps. #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. Rank would give me sequential numbers, making #' the person that came in third place (after the ties) would register as coming in fifth. -#' #' This is equivalent to the \code{DENSE_RANK} function in SQL. +#' The method should be used with no argument. #' -#' @rdname dense_rank -#' @name dense_rank -#' @family window_funcs -#' @aliases dense_rank,missing-method +#' @rdname column_window_functions +#' @aliases dense_rank dense_rank,missing-method #' @export -#' @examples \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(dense_rank(), ws), df$hp, df$am) -#' } #' @note dense_rank since 1.6.0 setMethod("dense_rank", signature("missing"), @@ -3332,33 +2923,15 @@ setMethod("dense_rank", column(jc) }) -#' lag -#' -#' Window function: returns the value that is \code{offset} rows before the current row, and +#' @details +#' \code{lag}: Returns the value that is \code{offset} rows before the current row, and #' \code{defaultValue} if there is less than \code{offset} rows before the current row. For example, #' an \code{offset} of one will return the previous row at any given point in the window partition. -#' #' This is equivalent to the \code{LAG} function in SQL. #' -#' @param x the column as a character string or a Column to compute on. -#' @param offset the number of rows back from the current row from which to obtain a value. -#' If not specified, the default is 1. -#' @param defaultValue (optional) default to use when the offset row does not exist. -#' @param ... further arguments to be passed to or from other methods. 
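The windowed forms take no argument and are applied through `over()` with a WindowSpec, as in the examples that previously sat on the individual rdname pages. A sketch on `mtcars`:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(mtcars)
ws <- orderBy(windowPartitionBy("am"), "hp")   # partition by transmission, order by horsepower

head(select(df,
            df$hp, df$am,
            over(cume_dist(), ws),
            over(dense_rank(), ws)))
```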
-#' @rdname lag -#' @name lag -#' @aliases lag,characterOrColumn-method -#' @family window_funcs +#' @rdname column_window_functions +#' @aliases lag lag,characterOrColumn-method #' @export -#' @examples \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Lag mpg values by 1 row on the partition-and-ordered table -#' out <- select(df, over(lag(df$mpg), ws), df$mpg, df$hp, df$am) -#' } #' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), @@ -3374,34 +2947,16 @@ setMethod("lag", column(jc) }) -#' lead -#' -#' Window function: returns the value that is \code{offset} rows after the current row, and +#' @details +#' \code{lead}: Returns the value that is \code{offset} rows after the current row, and #' \code{defaultValue} if there is less than \code{offset} rows after the current row. #' For example, an \code{offset} of one will return the next row at any given point #' in the window partition. -#' #' This is equivalent to the \code{LEAD} function in SQL. #' -#' @param x the column as a character string or a Column to compute on. -#' @param offset the number of rows after the current row from which to obtain a value. -#' If not specified, the default is 1. -#' @param defaultValue (optional) default to use when the offset row does not exist. -#' -#' @rdname lead -#' @name lead -#' @family window_funcs -#' @aliases lead,characterOrColumn,numeric-method +#' @rdname column_window_functions +#' @aliases lead lead,characterOrColumn,numeric-method #' @export -#' @examples \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Lead mpg values by 1 row on the partition-and-ordered table -#' out <- select(df, over(lead(df$mpg), ws), df$mpg, df$hp, df$am) -#' } #' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), @@ -3417,30 +2972,15 @@ setMethod("lead", column(jc) }) -#' ntile -#' -#' Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window +#' @details +#' \code{ntile}: Returns the ntile group id (from 1 to n inclusive) in an ordered window #' partition. For example, if n is 4, the first quarter of the rows will get value 1, the second #' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. -#' #' This is equivalent to the \code{NTILE} function in SQL. #' -#' @param x Number of ntile groups -#' -#' @rdname ntile -#' @name ntile -#' @aliases ntile,numeric-method -#' @family window_funcs +#' @rdname column_window_functions +#' @aliases ntile ntile,numeric-method #' @export -#' @examples \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Get ntile group id (1-4) for hp -#' out <- select(df, over(ntile(4), ws), df$hp, df$am) -#' } #' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), @@ -3449,26 +2989,15 @@ setMethod("ntile", column(jc) }) -#' percent_rank -#' -#' Window function: returns the relative rank (i.e. percentile) of rows within a window partition. -#' -#' This is computed by: -#' -#' (rank of row in its partition - 1) / (number of rows in the partition - 1) +#' @details +#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window partition. 
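`lag`, `lead` and `ntile` against the same kind of WindowSpec; `mtcars` again as the stand-in data:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(mtcars)
ws <- orderBy(windowPartitionBy("am"), "hp")

head(select(df,
            df$mpg, df$hp, df$am,
            over(lag(df$mpg), ws),       # previous row's mpg within the partition
            over(lead(df$mpg, 1), ws),   # next row's mpg
            over(ntile(4), ws)))         # quartile id (1-4) ordered by hp
```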
+#' This is computed by: (rank of row in its partition - 1) / (number of rows in the partition - 1). +#' This is equivalent to the \code{PERCENT_RANK} function in SQL. +#' The method should be used with no argument. #' -#' This is equivalent to the PERCENT_RANK function in SQL. -#' -#' @rdname percent_rank -#' @name percent_rank -#' @family window_funcs -#' @aliases percent_rank,missing-method +#' @rdname column_window_functions +#' @aliases percent_rank percent_rank,missing-method #' @export -#' @examples \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(percent_rank(), ws), df$hp, df$am) -#' } #' @note percent_rank since 1.6.0 setMethod("percent_rank", signature("missing"), @@ -3477,28 +3006,19 @@ setMethod("percent_rank", column(jc) }) -#' rank -#' -#' Window function: returns the rank of rows within a window partition. -#' +#' @details +#' \code{rank}: Returns the rank of rows within a window partition. #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. Rank would give me sequential numbers, making #' the person that came in third place (after the ties) would register as coming in fifth. +#' This is equivalent to the \code{RANK} function in SQL. +#' The method should be used with no argument. #' -#' This is equivalent to the RANK function in SQL. -#' -#' @rdname rank -#' @name rank -#' @family window_funcs -#' @aliases rank,missing-method +#' @rdname column_window_functions +#' @aliases rank rank,missing-method #' @export -#' @examples \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(rank(), ws), df$hp, df$am) -#' } #' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), @@ -3507,11 +3027,7 @@ setMethod("rank", column(jc) }) -# Expose rank() in the R base package -#' @param x a numeric, complex, character or logical vector. -#' @param ... additional argument(s) passed to the method. -#' @name rank -#' @rdname rank +#' @rdname column_window_functions #' @aliases rank,ANY-method #' @export setMethod("rank", @@ -3520,22 +3036,14 @@ setMethod("rank", base::rank(x, ...) }) -#' row_number -#' -#' Window function: returns a sequential number starting at 1 within a window partition. -#' -#' This is equivalent to the ROW_NUMBER function in SQL. +#' @details +#' \code{row_number}: Returns a sequential number starting at 1 within a window partition. +#' This is equivalent to the \code{ROW_NUMBER} function in SQL. +#' The method should be used with no argument. #' -#' @rdname row_number -#' @name row_number -#' @aliases row_number,missing-method -#' @family window_funcs +#' @rdname column_window_functions +#' @aliases row_number row_number,missing-method #' @export -#' @examples \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(row_number(), ws), df$hp, df$am) -#' } #' @note row_number since 1.6.0 setMethod("row_number", signature("missing"), @@ -3546,18 +3054,14 @@ setMethod("row_number", ###################### Collection functions###################### -#' array_contains -#' -#' Returns null if the array is null, true if the array contains the value, and false otherwise. 
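The three ranking functions side by side, which makes the gaps-versus-no-gaps distinction easy to see:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(mtcars)
ws <- orderBy(windowPartitionBy("am"), "hp")

head(select(df,
            df$hp, df$am,
            over(percent_rank(), ws),   # (rank - 1) / (rows in partition - 1)
            over(rank(), ws),           # leaves gaps after ties
            over(row_number(), ws)))    # sequential, no gaps
```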
+#' @details +#' \code{array_contains}: Returns null if the array is null, true if the array contains +#' the value, and false otherwise. #' -#' @param x A Column -#' @param value A value to be checked if contained in the column -#' @rdname array_contains -#' @aliases array_contains,Column-method -#' @name array_contains -#' @family collection_funcs +#' @param value a value to be checked if contained in the column +#' @rdname column_collection_functions +#' @aliases array_contains array_contains,Column-method #' @export -#' @examples \dontrun{array_contains(df$c, 1)} #' @note array_contains since 1.6.0 setMethod("array_contains", signature(x = "Column", value = "ANY"), @@ -3566,18 +3070,40 @@ setMethod("array_contains", column(jc) }) -#' explode +#' @details +#' \code{map_keys}: Returns an unordered array containing the keys of the map. #' -#' Creates a new row for each element in the given array or map column. +#' @rdname column_collection_functions +#' @aliases map_keys map_keys,Column-method +#' @export +#' @note map_keys since 2.3.0 +setMethod("map_keys", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_keys", x@jc) + column(jc) + }) + +#' @details +#' \code{map_values}: Returns an unordered array containing the values of the map. #' -#' @param x Column to compute on +#' @rdname column_collection_functions +#' @aliases map_values map_values,Column-method +#' @export +#' @note map_values since 2.3.0 +setMethod("map_values", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_values", x@jc) + column(jc) + }) + +#' @details +#' \code{explode}: Creates a new row for each element in the given array or map column. #' -#' @rdname explode -#' @name explode -#' @family collection_funcs -#' @aliases explode,Column-method +#' @rdname column_collection_functions +#' @aliases explode explode,Column-method #' @export -#' @examples \dontrun{explode(df$c)} #' @note explode since 1.5.0 setMethod("explode", signature(x = "Column"), @@ -3586,18 +3112,12 @@ setMethod("explode", column(jc) }) -#' size -#' -#' Returns length of array or map. -#' -#' @param x Column to compute on +#' @details +#' \code{size}: Returns length of array or map. #' -#' @rdname size -#' @name size -#' @aliases size,Column-method -#' @family collection_funcs +#' @rdname column_collection_functions +#' @aliases size size,Column-method #' @export -#' @examples \dontrun{size(df$c)} #' @note size since 1.5.0 setMethod("size", signature(x = "Column"), @@ -3606,25 +3126,16 @@ setMethod("size", column(jc) }) -#' sort_array -#' -#' Sorts the input array in ascending or descending order according +#' @details +#' \code{sort_array}: Sorts the input array in ascending or descending order according #' to the natural ordering of the array elements. #' -#' @param x A Column to sort -#' @param asc A logical flag indicating the sorting order. +#' @rdname column_collection_functions +#' @param asc a logical flag indicating the sorting order. #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. 
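The collection helpers, including the new `map_keys`/`map_values`, on a one-row frame built in SQL (invented for illustration):

```r
# assumes: library(SparkR); sparkR.session()
df <- sql("SELECT array(1, 2, 3) AS arr, map('a', 1, 'b', 2) AS m")

head(select(df,
            array_contains(df$arr, 2),
            size(df$arr),
            map_keys(df$m),      # unordered keys
            map_values(df$m)))   # unordered values

head(select(df, explode(df$arr)))   # one output row per array element
```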
-#' @rdname sort_array -#' @name sort_array -#' @aliases sort_array,Column-method -#' @family collection_funcs +#' @aliases sort_array sort_array,Column-method #' @export -#' @examples -#' \dontrun{ -#' sort_array(df$c) -#' sort_array(df$c, FALSE) -#' } #' @note sort_array since 1.6.0 setMethod("sort_array", signature(x = "Column"), @@ -3633,18 +3144,13 @@ setMethod("sort_array", column(jc) }) -#' posexplode -#' -#' Creates a new row for each element with position in the given array or map column. -#' -#' @param x Column to compute on +#' @details +#' \code{posexplode}: Creates a new row for each element with position in the given array +#' or map column. #' -#' @rdname posexplode -#' @name posexplode -#' @family collection_funcs -#' @aliases posexplode,Column-method +#' @rdname column_collection_functions +#' @aliases posexplode posexplode,Column-method #' @export -#' @examples \dontrun{posexplode(df$c)} #' @note posexplode since 2.1.0 setMethod("posexplode", signature(x = "Column"), @@ -3653,19 +3159,12 @@ setMethod("posexplode", column(jc) }) -#' create_array +#' @details +#' \code{create_array}: Creates a new array column. The input columns must all have the same data type. #' -#' Creates a new array column. The input columns must all have the same data type. -#' -#' @param x Column to compute on -#' @param ... additional Column(s). -#' -#' @family normal_funcs -#' @rdname create_array -#' @name create_array -#' @aliases create_array,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases create_array create_array,Column-method #' @export -#' @examples \dontrun{create_array(df$x, df$y, df$z)} #' @note create_array since 2.3.0 setMethod("create_array", signature(x = "Column"), @@ -3678,22 +3177,15 @@ setMethod("create_array", column(jc) }) -#' create_map -#' -#' Creates a new map column. The input columns must be grouped as key-value pairs, +#' @details +#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value pairs, #' e.g. (key1, value1, key2, value2, ...). #' The key columns must all have the same data type, and can't be null. #' The value columns must all have the same data type. #' -#' @param x Column to compute on -#' @param ... additional Column(s). -#' -#' @family normal_funcs -#' @rdname create_map -#' @name create_map -#' @aliases create_map,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases create_map create_map,Column-method #' @export -#' @examples \dontrun{create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))} #' @note create_map since 2.3.0 setMethod("create_map", signature(x = "Column"), @@ -3706,18 +3198,18 @@ setMethod("create_map", column(jc) }) -#' collect_list -#' -#' Creates a list of objects with duplicates. -#' -#' @param x Column to compute on +#' @details +#' \code{collect_list}: Creates a list of objects with duplicates. #' -#' @rdname collect_list -#' @name collect_list -#' @family agg_funcs -#' @aliases collect_list,Column-method +#' @rdname column_aggregate_functions +#' @aliases collect_list collect_list,Column-method #' @export -#' @examples \dontrun{collect_list(df$x)} +#' @examples +#' +#' \dontrun{ +#' df2 = df[df$mpg > 20, ] +#' collect(select(df2, collect_list(df2$gear))) +#' collect(select(df2, collect_set(df2$gear)))} #' @note collect_list since 2.3.0 setMethod("collect_list", signature(x = "Column"), @@ -3726,18 +3218,12 @@ setMethod("collect_list", column(jc) }) -#' collect_set +#' @details +#' \code{collect_set}: Creates a list of objects with duplicate elements eliminated. 
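Building arrays and maps from ordinary columns and taking them apart again; a sketch with invented columns:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(data.frame(k = c("a", "a", "b"),
                                 x = c(3, 1, 2),
                                 y = c(30, 10, 20),
                                 stringsAsFactors = FALSE))

tmp <- mutate(df,
              arr = create_array(df$x, df$y),
              kv  = create_map(lit("x"), df$x, lit("y"), df$y))

head(select(tmp, sort_array(tmp$arr), sort_array(tmp$arr, FALSE)))
head(select(tmp, tmp$k, posexplode(tmp$arr)))   # adds 'pos' and 'col' columns
```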
#' -#' Creates a list of objects with duplicate elements eliminated. -#' -#' @param x Column to compute on -#' -#' @rdname collect_set -#' @name collect_set -#' @family agg_funcs -#' @aliases collect_set,Column-method +#' @rdname column_aggregate_functions +#' @aliases collect_set collect_set,Column-method #' @export -#' @examples \dontrun{collect_set(df$x)} #' @note collect_set since 2.3.0 setMethod("collect_set", signature(x = "Column"), @@ -3746,27 +3232,20 @@ setMethod("collect_set", column(jc) }) -#' split_string -#' -#' Splits string on regular expression. +#' @details +#' \code{split_string}: Splits string on regular expression. +#' Equivalent to \code{split} SQL function. #' -#' Equivalent to \code{split} SQL function -#' -#' @param x Column to compute on -#' @param pattern Java regular expression -#' -#' @rdname split_string -#' @family string_funcs -#' @aliases split_string,Column-method +#' @rdname column_string_functions +#' @aliases split_string split_string,Column-method #' @export -#' @examples \dontrun{ -#' df <- read.text("README.md") -#' -#' head(select(df, split_string(df$value, "\\s+"))) +#' @examples #' +#' \dontrun{ +#' head(select(df, split_string(df$Sex, "a"))) +#' head(select(df, split_string(df$Class, "\\d"))) #' # This is equivalent to the following SQL expression -#' head(selectExpr(df, "split(value, '\\\\s+')")) -#' } +#' head(selectExpr(df, "split(Class, '\\\\d')"))} #' @note split_string 2.3.0 setMethod("split_string", signature(x = "Column", pattern = "character"), @@ -3775,27 +3254,20 @@ setMethod("split_string", column(jc) }) -#' repeat_string -#' -#' Repeats string n times. +#' @details +#' \code{repeat_string}: Repeats string n times. +#' Equivalent to \code{repeat} SQL function. #' -#' Equivalent to \code{repeat} SQL function -#' -#' @param x Column to compute on -#' @param n Number of repetitions -#' -#' @rdname repeat_string -#' @family string_funcs -#' @aliases repeat_string,Column-method +#' @param n number of repetitions. +#' @rdname column_string_functions +#' @aliases repeat_string repeat_string,Column-method #' @export -#' @examples \dontrun{ -#' df <- read.text("README.md") -#' -#' first(select(df, repeat_string(df$value, 3))) +#' @examples #' +#' \dontrun{ +#' head(select(df, repeat_string(df$Class, 3))) #' # This is equivalent to the following SQL expression -#' first(selectExpr(df, "repeat(value, 3)")) -#' } +#' head(selectExpr(df, "repeat(Class, 3)"))} #' @note repeat_string since 2.3.0 setMethod("repeat_string", signature(x = "Column", n = "numeric"), @@ -3804,26 +3276,24 @@ setMethod("repeat_string", column(jc) }) -#' explode_outer -#' -#' Creates a new row for each element in the given array or map column. +#' @details +#' \code{explode}: Creates a new row for each element in the given array or map column. #' Unlike \code{explode}, if the array/map is \code{null} or empty #' then \code{null} is produced. 
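`collect_set` as a grouped aggregate, plus the new 2.3.0 string helpers; data invented:

```r
# assumes: library(SparkR); sparkR.session()
df <- createDataFrame(data.frame(g = c("a", "a", "b"),
                                 s = c("1,2", "2,3", "4"),
                                 stringsAsFactors = FALSE))

# Distinct values of s per group
head(agg(groupBy(df, "g"), vals = collect_set(df$s)))

# split / repeat, equivalent to the SQL functions of the same names
head(select(df, split_string(df$s, ","), repeat_string(df$s, 2)))
```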
#' -#' @param x Column to compute on #' -#' @rdname explode_outer -#' @name explode_outer -#' @family collection_funcs -#' @aliases explode_outer,Column-method +#' @rdname column_collection_functions +#' @aliases explode_outer explode_outer,Column-method #' @export -#' @examples \dontrun{ -#' df <- createDataFrame(data.frame( +#' @examples +#' +#' \dontrun{ +#' df2 <- createDataFrame(data.frame( #' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") #' )) #' -#' head(select(df, df$id, explode_outer(split_string(df$text, ",")))) -#' } +#' head(select(df2, df2$id, explode_outer(split_string(df2$text, ",")))) +#' head(select(df2, df2$id, posexplode_outer(split_string(df2$text, ","))))} #' @note explode_outer since 2.3.0 setMethod("explode_outer", signature(x = "Column"), @@ -3832,26 +3302,14 @@ setMethod("explode_outer", column(jc) }) -#' posexplode_outer -#' -#' Creates a new row for each element with position in the given array or map column. -#' Unlike \code{posexplode}, if the array/map is \code{null} or empty +#' @details +#' \code{posexplode_outer}: Creates a new row for each element with position in the given +#' array or map column. Unlike \code{posexplode}, if the array/map is \code{null} or empty #' then the row (\code{null}, \code{null}) is produced. #' -#' @param x Column to compute on -#' -#' @rdname posexplode_outer -#' @name posexplode_outer -#' @family collection_funcs -#' @aliases posexplode_outer,Column-method +#' @rdname column_collection_functions +#' @aliases posexplode_outer posexplode_outer,Column-method #' @export -#' @examples \dontrun{ -#' df <- createDataFrame(data.frame( -#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") -#' )) -#' -#' head(select(df, df$id, posexplode_outer(split_string(df$text, ",")))) -#' } #' @note posexplode_outer since 2.3.0 setMethod("posexplode_outer", signature(x = "Column"), @@ -3871,8 +3329,10 @@ setMethod("posexplode_outer", #' @rdname not #' @name not #' @aliases not,Column-method +#' @family non-aggregate functions #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(data.frame( #' is_true = c(TRUE, FALSE, NA), #' flag = c(1, 0, 1) @@ -3890,3 +3350,111 @@ setMethod("not", jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc) column(jc) }) + +#' @details +#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or not, +#' returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} in SQL +#' and \code{grouping} function in Scala. +#' +#' @rdname column_aggregate_functions +#' @aliases grouping_bit grouping_bit,Column-method +#' @export +#' @examples +#' +#' \dontrun{ +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' )} +#' @note grouping_bit since 2.3.0 +setMethod("grouping_bit", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "grouping", x@jc) + column(jc) + }) + +#' @details +#' \code{grouping_id}: Returns the level of grouping. +#' Equals to \code{ +#' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) +#' }. 
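The `_outer` generators keep rows whose array/map is null or empty, and `grouping_bit`/`grouping_id` relate as individual bit and bit vector. A combined sketch; both frames are invented:

```r
# assumes: library(SparkR); sparkR.session()
df2 <- createDataFrame(data.frame(id = c(1, 2),
                                  text = c("a,b", NA),
                                  stringsAsFactors = FALSE))

# The row with id == 2 survives with NULLs instead of being dropped
head(select(df2, df2$id, explode_outer(split_string(df2$text, ","))))
head(select(df2, df2$id, posexplode_outer(split_string(df2$text, ","))))
head(select(df2, not(isNull(df2$text))))

# grouping_id equals the bit vector formed from the individual grouping_bit values
df <- createDataFrame(mtcars)
head(agg(cube(df, "cyl", "am"),
         avg_mpg = mean(df$mpg),
         bit_cyl = grouping_bit(df$cyl),
         bit_am  = grouping_bit(df$am),
         gid     = grouping_id(df$cyl, df$am)))   # gid == 2 * bit_cyl + bit_am
```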
+#' +#' @rdname column_aggregate_functions +#' @aliases grouping_id grouping_id,Column-method +#' @export +#' @examples +#' +#' \dontrun{ +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' )} +#' @note grouping_id since 2.3.0 +setMethod("grouping_id", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "grouping_id", jcols) + column(jc) + }) + +#' @details +#' \code{input_file_name}: Creates a string column with the input file name for a given row. +#' The method should be used with no argument. +#' +#' @rdname column_nonaggregate_functions +#' @aliases input_file_name input_file_name,missing-method +#' @export +#' @examples +#' +#' \dontrun{ +#' tmp <- read.text("README.md") +#' head(select(tmp, input_file_name()))} +#' @note input_file_name since 2.3.0 +setMethod("input_file_name", signature("missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "input_file_name") + column(jc) + }) + +#' @details +#' \code{trunc}: Returns date truncated to the unit specified by the format. +#' +#' @rdname column_datetime_functions +#' @aliases trunc trunc,Column-method +#' @export +#' @examples +#' +#' \dontrun{ +#' head(select(df, df$time, trunc(df$time, "year"), trunc(df$time, "yy"), +#' trunc(df$time, "month"), trunc(df$time, "mon")))} +#' @note trunc since 2.3.0 +setMethod("trunc", + signature(x = "Column"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "trunc", + x@jc, as.character(format)) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ef36765a7a72..0fe8f0453b06 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -387,6 +387,17 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @export setGeneric("agg", function (x, ...) { standardGeneric("agg") }) +#' alias +#' +#' Returns a new SparkDataFrame or a Column with an alias set. Equivalent to SQL "AS" keyword. +#' +#' @name alias +#' @rdname alias +#' @param object x a SparkDataFrame or a Column +#' @param data new name to use +#' @return a SparkDataFrame or a Column +NULL + #' @rdname arrange #' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) @@ -411,9 +422,8 @@ setGeneric("cache", function(x) { standardGeneric("cache") }) setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) #' @rdname coalesce -#' @param x a Column or a SparkDataFrame. -#' @param ... additional argument(s). If \code{x} is a Column, additional Columns can be optionally -#' provided. +#' @param x a SparkDataFrame. +#' @param ... additional argument(s). #' @export setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") }) @@ -468,7 +478,7 @@ setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @export setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) -#' @rdname covar_pop +#' @rdname cov #' @export setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) @@ -511,7 +521,7 @@ setGeneric("gapplyCollect", function(x, ...) 
{ standardGeneric("gapplyCollect") # @export setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) -#' @rdname summary +#' @rdname describe #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) @@ -576,6 +586,10 @@ setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) +#' @rdname hint +#' @export +setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) + #' @rdname insertInto #' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) @@ -631,7 +645,7 @@ setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) #' @rdname sample #' @export setGeneric("sample", - function(x, withReplacement, fraction, seed) { + function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample") }) @@ -642,7 +656,7 @@ setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) #' @rdname sample #' @export setGeneric("sample_frac", - function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) + function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname sampleBy #' @export @@ -755,6 +769,10 @@ setGeneric("union", function(x, y) { standardGeneric("union") }) #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) +#' @rdname unionByName +#' @export +setGeneric("unionByName", function(x, y) { standardGeneric("unionByName") }) + #' @rdname unpersist #' @export setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) @@ -784,6 +802,10 @@ setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.d #' @export setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) +#' @rdname broadcast +#' @export +setGeneric("broadcast", function(x) { standardGeneric("broadcast") }) + ###################### Column Methods ########################## #' @rdname columnfunctions @@ -844,8 +866,9 @@ setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @export setGeneric("startsWith", function(x, prefix) { standardGeneric("startsWith") }) -#' @rdname when +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname otherwise @@ -884,20 +907,24 @@ setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy" ###################### Expression Function Methods ########################## -#' @rdname add_months +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) -#' @rdname approxCountDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) -#' @rdname array_contains +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) -#' @rdname ascii +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @param x Column to compute on or a GroupedData object. @@ -906,496 +933,633 @@ setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @export setGeneric("avg", function(x, ...) 
{ standardGeneric("avg") }) -#' @rdname base64 +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("base64", function(x) { standardGeneric("base64") }) -#' @rdname bin +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) -#' @rdname bitwiseNOT +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) -#' @rdname bround +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("bround", function(x, ...) { standardGeneric("bround") }) -#' @rdname cbrt +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) -#' @rdname ceil +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("ceil", function(x) { standardGeneric("ceil") }) -#' @rdname collect_list +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) -#' @rdname collect_set +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column #' @export setGeneric("column", function(x) { standardGeneric("column") }) -#' @rdname concat +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) -#' @rdname concat_ws +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) -#' @rdname conv +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) -#' @rdname countDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) -#' @rdname crc32 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("crc32", function(x) { standardGeneric("crc32") }) -#' @rdname create_array +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) -#' @rdname create_map +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) -#' @rdname hash +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) -#' @param x empty. Should be used with no argument. 
-#' @rdname cume_dist +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) -#' @rdname datediff +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) -#' @rdname date_add +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) -#' @rdname date_format +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) -#' @rdname date_sub +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) -#' @rdname dayofmonth +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) -#' @rdname dayofyear +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) -#' @rdname decode +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) -#' @param x empty. Should be used with no argument. -#' @rdname dense_rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) -#' @rdname encode +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) -#' @rdname explode +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("explode", function(x) { standardGeneric("explode") }) -#' @rdname explode_outer +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) -#' @rdname expr +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) -#' @rdname from_utc_timestamp +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) -#' @rdname format_number +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) -#' @rdname format_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) -#' @rdname from_json +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) -#' @rdname from_unixtime +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) -#' @rdname greatest +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) -#' @rdname hex +#' @rdname column_aggregate_functions #' @export +#' @name NULL +setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) + +#' @rdname column_aggregate_functions +#' @export +#' @name NULL +setGeneric("grouping_id", function(x, ...) 
{ standardGeneric("grouping_id") }) + +#' @rdname column_math_functions +#' @export +#' @name NULL setGeneric("hex", function(x) { standardGeneric("hex") }) -#' @rdname hour +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) -#' @rdname hypot +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) -#' @rdname initcap +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) -#' @rdname instr +#' @rdname column_nonaggregate_functions +#' @export +#' @name NULL +setGeneric("input_file_name", + function(x = "missing") { standardGeneric("input_file_name") }) + +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) -#' @rdname is.nan +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("isnan", function(x) { standardGeneric("isnan") }) -#' @rdname kurtosis +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) -#' @rdname lag +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last #' @export setGeneric("last", function(x, ...) { standardGeneric("last") }) -#' @rdname last_day +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) -#' @rdname lead +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) -#' @rdname least +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("least", function(x, ...) { standardGeneric("least") }) -#' @rdname levenshtein +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) -#' @rdname lit +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("lit", function(x) { standardGeneric("lit") }) -#' @rdname locate +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) -#' @rdname lower +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("lower", function(x) { standardGeneric("lower") }) -#' @rdname lpad +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) -#' @rdname ltrim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) -#' @rdname md5 +#' @rdname column_collection_functions #' @export +#' @name NULL +setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) + +#' @rdname column_collection_functions +#' @export +#' @name NULL +setGeneric("map_values", function(x) { standardGeneric("map_values") }) + +#' @rdname column_misc_functions +#' @export +#' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) -#' @rdname minute +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) -#' @param x empty. Should be used with no argument. 
-#' @rdname monotonically_increasing_id +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) -#' @rdname month +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("month", function(x) { standardGeneric("month") }) -#' @rdname months_between +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) #' @rdname count #' @export setGeneric("n", function(x) { standardGeneric("n") }) -#' @rdname nanvl +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) -#' @rdname negate +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("negate", function(x) { standardGeneric("negate") }) #' @rdname not #' @export setGeneric("not", function(x) { standardGeneric("not") }) -#' @rdname next_day +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) -#' @rdname ntile +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("ntile", function(x) { standardGeneric("ntile") }) -#' @rdname countDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) -#' @param x empty. Should be used with no argument. -#' @rdname percent_rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) -#' @rdname pmod +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) -#' @rdname posexplode +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) -#' @rdname posexplode_outer +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) -#' @rdname quarter +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) -#' @rdname rand +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("rand", function(seed) { standardGeneric("rand") }) -#' @rdname randn +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) -#' @rdname rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("rank", function(x, ...) 
{ standardGeneric("rank") }) -#' @rdname regexp_extract +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) -#' @rdname regexp_replace +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) -#' @rdname repeat_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) -#' @rdname reverse +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) -#' @rdname rint +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) -#' @param x empty. Should be used with no argument. -#' @rdname row_number +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) -#' @rdname rpad +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) -#' @rdname rtrim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) -#' @rdname sd +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) -#' @rdname second +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) -#' @rdname sha1 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("sha1", function(x) { standardGeneric("sha1") }) -#' @rdname sha2 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) -#' @rdname shiftLeft +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) -#' @rdname shiftRight +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) -#' @rdname shiftRightUnsigned +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) -#' @rdname sign +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) -#' @rdname size +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("size", function(x) { standardGeneric("size") }) -#' @rdname skewness +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) -#' @rdname sort_array +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) -#' @rdname split_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) -#' @rdname soundex +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) -#' @param x empty. Should be used with no argument. 
-#' @rdname spark_partition_id +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) -#' @rdname sd +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("stddev", function(x) { standardGeneric("stddev") }) -#' @rdname stddev_pop +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) -#' @rdname stddev_samp +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) -#' @rdname struct +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("struct", function(x, ...) { standardGeneric("struct") }) -#' @rdname substring_index +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) -#' @rdname sumDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) -#' @rdname toDegrees +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname toRadians +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) -#' @rdname to_date +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) -#' @rdname to_json +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) -#' @rdname to_timestamp +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) -#' @rdname to_utc_timestamp +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) -#' @rdname translate +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) -#' @rdname trim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("trim", function(x) { standardGeneric("trim") }) -#' @rdname unbase64 +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) -#' @rdname unhex +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("unhex", function(x) { standardGeneric("unhex") }) -#' @rdname unix_timestamp +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) -#' @rdname upper +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("upper", function(x) { standardGeneric("upper") }) -#' @rdname var +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) -#' @rdname var +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("variance", function(x) { standardGeneric("variance") }) -#' @rdname var_pop +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) 
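[Editor's note] The aggregate generics above (`sd`, `stddev_*`, `var_*`, `sumDistinct`, ...) now share the `column_aggregate_functions` page. A compact, assumed-data sketch of how they are typically combined with `agg`/`groupBy`:

```r
df <- createDataFrame(mtcars)
head(agg(groupBy(df, "cyl"),
         stddev_samp(df$mpg), var_pop(df$mpg), sumDistinct(df$mpg)))
```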
-#' @rdname var_samp +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) -#' @rdname weekofyear +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) -#' @rdname window +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("window", function(x, ...) { standardGeneric("window") }) -#' @rdname year +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("year", function(x) { standardGeneric("year") }) @@ -1473,6 +1637,11 @@ setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.ml #' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) +#' @rdname spark.decisionTree +#' @export +setGeneric("spark.decisionTree", + function(data, formula, ...) { standardGeneric("spark.decisionTree") }) + #' @rdname spark.randomForest #' @export setGeneric("spark.randomForest", diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 17f5283abead..0a7be0e99397 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -233,6 +233,9 @@ setMethod("gapplyCollect", }) gapplyInternal <- function(x, func, schema) { + if (is.character(schema)) { + schema <- structType(schema) + } packageNamesArr <- serialize(.sparkREnv[[".packages"]], connection = NULL) broadcastArr <- lapply(ls(.broadcastNames), diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 4ca7aa664e02..492dee68e164 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -267,10 +267,14 @@ hadoopVersionName <- function(hadoopVersion) { # The implementation refers to appdirs package: https://pypi.python.org/pypi/appdirs and # adapt to Spark context sparkCachePath <- function() { - if (.Platform$OS.type == "windows") { + if (is_windows()) { winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA) if (is.na(winAppPath)) { - stop(paste("%LOCALAPPDATA% not found.", + message("%LOCALAPPDATA% not found. Falling back to %USERPROFILE%.") + winAppPath <- Sys.getenv("USERPROFILE", unset = NA) + } + if (is.na(winAppPath)) { + stop(paste("%LOCALAPPDATA% and %USERPROFILE% not found.", "Please define the environment variable", "or restart and enter an installation path in localDir.")) } else { diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 4db9cc30fb0c..15af8298ba48 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -46,26 +46,34 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj" #' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) -#' linear SVM Model +#' Linear SVM Model #' -#' Fits an linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in glmnet package +#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package. +#' Currently only supports binary classification model with linear kernel. #' Users can print, make predictions on the produced model and save the model to the input path. #' #' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam The regularization parameter. +#' @param regParam The regularization parameter. Only supports L2 regularization currently. #' @param maxIter Maximum iteration number. 
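[Editor's note] The `gapplyInternal` change above lets the schema be passed as a character string, which is now converted via `structType()`. A sketch of that calling style, assuming an active session; the DDL string and the `max_mpg` output column are assumptions of this sketch:

```r
df <- createDataFrame(mtcars)
result <- gapply(df, "cyl",
                 function(key, x) {
                   data.frame(key, max_mpg = max(x$mpg))
                 },
                 "cyl DOUBLE, max_mpg DOUBLE")
head(result)
```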
#' @param tol Convergence tolerance of iterations. #' @param standardization Whether to standardize the training features before fitting the model. The coefficients #' of models will be always returned on the original scale, so it will be transparent for #' users. Note that with/without standardization, the models should be always converged #' to the same solution when no regularization is applied. -#' @param threshold The threshold in binary classification, in range [0, 1]. +#' @param threshold The threshold in binary classification applied to the linear model prediction. +#' This threshold can be any real number, where Inf will make all predictions 0.0 +#' and -Inf will make all predictions 1.0. #' @param weightCol The weight column name. #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.svmLinear} returns a fitted linear SVM model. #' @rdname spark.svmLinear @@ -95,7 +103,8 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @note spark.svmLinear since 2.2.0 setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE, - threshold = 0.0, weightCol = NULL, aggregationDepth = 2) { + threshold = 0.0, weightCol = NULL, aggregationDepth = 2, + handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") if (!is.null(weightCol) && weightCol == "") { @@ -104,17 +113,19 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu weightCol <- as.character(weightCol) } + handleInvalid <- match.arg(handleInvalid) + jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit", data@sdf, formula, as.numeric(regParam), as.integer(maxIter), as.numeric(tol), as.logical(standardization), as.numeric(threshold), - weightCol, as.integer(aggregationDepth)) + weightCol, as.integer(aggregationDepth), handleInvalid) new("LinearSVCModel", jobj = jobj) }) -# Predicted values based on an LinearSVCModel model +# Predicted values based on a LinearSVCModel model #' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns the predicted values based on an LinearSVCModel. +#' @return \code{predict} returns the predicted values based on a LinearSVCModel. #' @rdname spark.svmLinear #' @aliases predict,LinearSVCModel,SparkDataFrame-method #' @export @@ -124,13 +135,12 @@ setMethod("predict", signature(object = "LinearSVCModel"), predict_internal(object, newData) }) -# Get the summary of an LinearSVCModel +# Get the summary of a LinearSVCModel -#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}. +#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}. #' @return \code{summary} returns summary information of the fitted model, which is a list. 
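[Editor's note] A hedged usage sketch for the `handleInvalid` argument added to `spark.svmLinear` above; the Titanic-based data and the chosen `regParam` are illustrative only:

```r
t <- as.data.frame(Titanic, stringsAsFactors = FALSE)
df <- createDataFrame(t)
# "keep" buckets unseen string labels instead of raising an error
model <- spark.svmLinear(df, Survived ~ Class + Sex, regParam = 0.01,
                         handleInvalid = "keep")
summary(model)
head(predict(model, df))
```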
#' The list includes \code{coefficients} (coefficients of the fitted model), -#' \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes), -#' \code{numFeatures} (number of features). +#' \code{numClasses} (number of classes), \code{numFeatures} (number of features). #' @rdname spark.svmLinear #' @aliases summary,LinearSVCModel-method #' @export @@ -138,22 +148,14 @@ setMethod("predict", signature(object = "LinearSVCModel"), setMethod("summary", signature(object = "LinearSVCModel"), function(object) { jobj <- object@jobj - features <- callJMethod(jobj, "features") - labels <- callJMethod(jobj, "labels") - coefficients <- callJMethod(jobj, "coefficients") - nCol <- length(coefficients) / length(features) - coefficients <- matrix(unlist(coefficients), ncol = nCol) - intercept <- callJMethod(jobj, "intercept") + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) numClasses <- callJMethod(jobj, "numClasses") numFeatures <- callJMethod(jobj, "numFeatures") - if (nCol == 1) { - colnames(coefficients) <- c("Estimate") - } else { - colnames(coefficients) <- unlist(labels) - } - rownames(coefficients) <- unlist(features) - list(coefficients = coefficients, intercept = intercept, - numClasses = numClasses, numFeatures = numFeatures) + list(coefficients = coefficients, numClasses = numClasses, numFeatures = numFeatures) }) # Save fitted LinearSVCModel to the input path @@ -210,6 +212,25 @@ function(object, path, overwrite = FALSE) { #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. +#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound constrained optimization. +#' The bound matrix must be compatible with the shape (1, number of features) for binomial +#' regression, or (number of classes, number of features) for multinomial regression. +#' It is a R matrix. +#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound constrained optimization. +#' The bound matrix must be compatible with the shape (1, number of features) for binomial +#' regression, or (number of classes, number of features) for multinomial regression. +#' It is a R matrix. +#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained optimization. +#' The bounds vector size must be equal to 1 for binomial regression, or the number +#' of classes for multinomial regression. +#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization. +#' The bound vector size must be equal to 1 for binomial regression, or the number +#' of classes for multinomial regression. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. 
#' @rdname spark.logit @@ -247,8 +268,13 @@ function(object, path, overwrite = FALSE) { setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, tol = 1E-6, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) { + thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, + lowerBoundsOnCoefficients = NULL, upperBoundsOnCoefficients = NULL, + lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL, + handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") + row <- 0 + col <- 0 if (!is.null(weightCol) && weightCol == "") { weightCol <- NULL @@ -256,12 +282,54 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") weightCol <- as.character(weightCol) } + if (!is.null(lowerBoundsOnIntercepts)) { + lowerBoundsOnIntercepts <- as.array(lowerBoundsOnIntercepts) + } + + if (!is.null(upperBoundsOnIntercepts)) { + upperBoundsOnIntercepts <- as.array(upperBoundsOnIntercepts) + } + + if (!is.null(lowerBoundsOnCoefficients)) { + if (class(lowerBoundsOnCoefficients) != "matrix") { + stop("lowerBoundsOnCoefficients must be a matrix.") + } + row <- nrow(lowerBoundsOnCoefficients) + col <- ncol(lowerBoundsOnCoefficients) + lowerBoundsOnCoefficients <- as.array(as.vector(lowerBoundsOnCoefficients)) + } + + if (!is.null(upperBoundsOnCoefficients)) { + if (class(upperBoundsOnCoefficients) != "matrix") { + stop("upperBoundsOnCoefficients must be a matrix.") + } + + if (!is.null(lowerBoundsOnCoefficients) && (row != nrow(upperBoundsOnCoefficients) + || col != ncol(upperBoundsOnCoefficients))) { + stop(paste0("dimension of upperBoundsOnCoefficients ", + "is not the same as lowerBoundsOnCoefficients", sep = "")) + } + + if (is.null(lowerBoundsOnCoefficients)) { + row <- nrow(upperBoundsOnCoefficients) + col <- ncol(upperBoundsOnCoefficients) + } + + upperBoundsOnCoefficients <- as.array(as.vector(upperBoundsOnCoefficients)) + } + + handleInvalid <- match.arg(handleInvalid) + jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", data@sdf, formula, as.numeric(regParam), as.numeric(elasticNetParam), as.integer(maxIter), as.numeric(tol), as.character(family), as.logical(standardization), as.array(thresholds), - weightCol, as.integer(aggregationDepth)) + weightCol, as.integer(aggregationDepth), + as.integer(row), as.integer(col), + lowerBoundsOnCoefficients, upperBoundsOnCoefficients, + lowerBoundsOnIntercepts, upperBoundsOnIntercepts, + handleInvalid) new("LogisticRegressionModel", jobj = jobj) }) @@ -343,7 +411,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @param stepSize stepSize parameter. #' @param seed seed parameter for weights initialization. #' @param initialWeights initialWeights parameter for weights initialization, it should be a -#' numeric vector. +#' numeric vector. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. 
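[Editor's note] A sketch of the new bound-constrained arguments to `spark.logit` shown above. The data (`mtcars`), formula, and bound values are assumptions of this sketch; the shape requirement comes from the parameter docs (1 x number-of-features for binomial regression):

```r
df <- createDataFrame(mtcars)
# am ~ mpg + wt yields two features, so the bound matrices are 1 x 2
lb <- matrix(c(-1, -1), nrow = 1)
ub <- matrix(c(1, 1), nrow = 1)
model <- spark.logit(df, am ~ mpg + wt,
                     lowerBoundsOnCoefficients = lb,
                     upperBoundsOnCoefficients = ub,
                     lowerBoundsOnIntercepts = -1,
                     upperBoundsOnIntercepts = 1)
summary(model)
```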
#' @rdname spark.mlp @@ -375,7 +448,8 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @note spark.mlp since 2.1.0 setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, - tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) { + tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL, + handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") if (is.null(layers)) { stop ("layers must be a integer vector with length > 1.") @@ -390,10 +464,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), if (!is.null(initialWeights)) { initialWeights <- as.array(as.numeric(na.omit(initialWeights))) } + handleInvalid <- match.arg(handleInvalid) jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", "fit", data@sdf, formula, as.integer(blockSize), as.array(layers), as.character(solver), as.integer(maxIter), as.numeric(tol), - as.numeric(stepSize), seed, initialWeights) + as.numeric(stepSize), seed, initialWeights, handleInvalid) new("MultilayerPerceptronClassificationModel", jobj = jobj) }) @@ -463,6 +538,11 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param smoothing smoothing parameter. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. #' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. #' @rdname spark.naiveBayes @@ -492,10 +572,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' } #' @note spark.naiveBayes since 2.0.0 setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, smoothing = 1.0) { + function(data, formula, smoothing = 1.0, + handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") + handleInvalid <- match.arg(handleInvalid) jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit", - formula, data@sdf, smoothing) + formula, data@sdf, smoothing, handleInvalid) new("NaiveBayesModel", jobj = jobj) }) diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index d59c890f3e5f..ebaeae970218 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -70,6 +70,14 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' the relationship between the variance and mean of the distribution. Only #' applicable to the Tweedie family. #' @param link.power the index in the power link function. Only applicable to the Tweedie family. +#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to +#' decide the base level of a string feature as the last category after +#' ordering is dropped when encoding strings. Supported options are +#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". 
+#' The default value is "frequencyDesc". When the ordering is set to +#' "alphabetDesc", this drops the same category as R when encoding strings. +#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets +#' as 0.0. The feature specified as offset has a constant coefficient of 1.0. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. @@ -79,7 +87,7 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' t <- as.data.frame(Titanic) +#' t <- as.data.frame(Titanic, stringsAsFactors = FALSE) #' df <- createDataFrame(t) #' model <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian") #' summary(model) @@ -96,6 +104,15 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' savedModel <- read.ml(path) #' summary(savedModel) #' +#' # note that the default string encoding is different from R's glm +#' model2 <- glm(Freq ~ Sex + Age, family = "gaussian", data = t) +#' summary(model2) +#' # use stringIndexerOrderType = "alphabetDesc" to force string encoding +#' # to be consistent with R +#' model3 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian", +#' stringIndexerOrderType = "alphabetDesc") +#' summary(model3) +#' #' # fit tweedie model #' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie", #' var.power = 1.2, link.power = 0) @@ -110,8 +127,12 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, - regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power) { + regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power, + stringIndexerOrderType = c("frequencyDesc", "frequencyAsc", + "alphabetDesc", "alphabetAsc"), + offsetCol = NULL) { + stringIndexerOrderType <- match.arg(stringIndexerOrderType) if (is.character(family)) { # Handle when family = "tweedie" if (tolower(family) == "tweedie") { @@ -141,11 +162,19 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), weightCol <- as.character(weightCol) } + if (!is.null(offsetCol)) { + offsetCol <- as.character(offsetCol) + if (nchar(offsetCol) == 0) { + offsetCol <- NULL + } + } + # For known families, Gamma is upper-cased jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, tolower(family$family), family$link, tol, as.integer(maxIter), weightCol, regParam, - as.double(var.power), as.double(link.power)) + as.double(var.power), as.double(link.power), + stringIndexerOrderType, offsetCol) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -167,6 +196,14 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @param maxit integer giving the maximal number of IRLS iterations. #' @param var.power the index of the power variance function in the Tweedie family. #' @param link.power the index of the power link function in the Tweedie family. +#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to +#' decide the base level of a string feature as the last category after +#' ordering is dropped when encoding strings. Supported options are +#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". 
+#' The default value is "frequencyDesc". When the ordering is set to +#' "alphabetDesc", this drops the same category as R when encoding strings. +#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets +#' as 0.0. The feature specified as offset has a constant coefficient of 1.0. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -182,9 +219,14 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL, - var.power = 0.0, link.power = 1.0 - var.power) { + var.power = 0.0, link.power = 1.0 - var.power, + stringIndexerOrderType = c("frequencyDesc", "frequencyAsc", + "alphabetDesc", "alphabetAsc"), + offsetCol = NULL) { spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol, - var.power = var.power, link.power = link.power) + var.power = var.power, link.power = link.power, + stringIndexerOrderType = stringIndexerOrderType, + offsetCol = offsetCol) }) # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). @@ -418,6 +460,12 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. +#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to +#' decide the base level of a string feature as the last category after +#' ordering is dropped when encoding strings. Supported options are +#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". +#' The default value is "frequencyDesc". When the ordering is set to +#' "alphabetDesc", this drops the same category as R when encoding strings. #' @param ... additional arguments passed to the method. #' @return \code{spark.survreg} returns a fitted AFT survival regression model. 
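[Editor's note] A sketch showing the R-style `glm()` wrapper forwarding the new `stringIndexerOrderType` and `offsetCol` arguments. The `off` column is a hypothetical offset introduced for this sketch only:

```r
t <- as.data.frame(Titanic, stringsAsFactors = FALSE)
t$off <- log(t$Freq + 1)   # illustrative per-row offsets
df <- createDataFrame(t)
model <- glm(Freq ~ Sex + Age, family = "gaussian", data = df,
             stringIndexerOrderType = "alphabetDesc", offsetCol = "off")
summary(model)
```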
#' @rdname spark.survreg @@ -443,10 +491,14 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' } #' @note spark.survreg since 2.0.0 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, aggregationDepth = 2) { + function(data, formula, aggregationDepth = 2, + stringIndexerOrderType = c("frequencyDesc", "frequencyAsc", + "alphabetDesc", "alphabetAsc")) { + stringIndexerOrderType <- match.arg(stringIndexerOrderType) formula <- paste(deparse(formula), collapse = "") jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", - "fit", formula, data@sdf, as.integer(aggregationDepth)) + "fit", formula, data@sdf, as.integer(aggregationDepth), + stringIndexerOrderType) new("AFTSurvivalRegressionModel", jobj = jobj) }) diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 82279be6fbe7..33c4653f4c18 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) +#' S4 class that represents a DecisionTreeRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel +#' @export +#' @note DecisionTreeRegressionModel since 2.3.0 +setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a DecisionTreeClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel +#' @export +#' @note DecisionTreeClassificationModel since 2.3.0 +setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) + # Create the summary of a tree ensemble model (eg. Random Forest, GBT) summary.treeEnsemble <- function(model) { jobj <- model@jobj @@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) { invisible(x) } +# Create the summary of a decision tree model +summary.decisionTree <- function(model) { + jobj <- model@jobj + formula <- callJMethod(jobj, "formula") + numFeatures <- callJMethod(jobj, "numFeatures") + features <- callJMethod(jobj, "features") + featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + maxDepth <- callJMethod(jobj, "maxDepth") + list(formula = formula, + numFeatures = numFeatures, + features = features, + featureImportances = featureImportances, + maxDepth = maxDepth, + jobj = jobj) +} + +# Prints the summary of decision tree models +print.summary.decisionTree <- function(x) { + jobj <- x$jobj + cat("Formula: ", x$formula) + cat("\nNumber of features: ", x$numFeatures) + cat("\nFeatures: ", unlist(x$features)) + cat("\nFeature importances: ", x$featureImportances) + cat("\nMax Depth: ", x$maxDepth) + + summaryStr <- callJMethod(jobj, "summary") + cat("\n", summaryStr, "\n") + invisible(x) +} + #' Gradient Boosted Tree Model for Regression and Classification #' #' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a @@ -120,6 +164,11 @@ print.summary.treeEnsemble <- function(x) { #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. 
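[Editor's note] A sketch of where the new `stringIndexerOrderType` argument goes in `spark.survreg`. It assumes the `survival` package is available for `Surv()` and the `ovarian` data; since `ovarian`'s features are numeric, the ordering argument does not change the fit here and only illustrates the call shape:

```r
library(survival)
df <- createDataFrame(ovarian)
model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx,
                       stringIndexerOrderType = "alphabetDesc")
summary(model)
```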
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type in classification model. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.gbt,SparkDataFrame,formula-method #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. @@ -161,7 +210,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, type = c("regression", "classification"), maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL, seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, - checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) { + checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE, + handleInvalid = c("error", "keep", "skip")) { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -181,6 +231,7 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), new("GBTRegressionModel", jobj = jobj) }, classification = { + handleInvalid <- match.arg(handleInvalid) if (is.null(lossType)) lossType <- "logistic" lossType <- match.arg(lossType, "logistic") jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", @@ -189,7 +240,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), as.numeric(stepSize), as.integer(minInstancesPerNode), as.numeric(minInfoGain), as.integer(checkpointInterval), lossType, seed, as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + as.integer(maxMemoryInMB), as.logical(cacheNodeIds), + handleInvalid) new("GBTClassificationModel", jobj = jobj) } ) @@ -330,6 +382,11 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type in classification model. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. 
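[Editor's note] A hedged sketch of the `handleInvalid` option wired into the classification branch of `spark.gbt` above; the Titanic-based data and `maxIter` value are illustrative assumptions:

```r
t <- as.data.frame(Titanic, stringsAsFactors = FALSE)
df <- createDataFrame(t)
model <- spark.gbt(df, Survived ~ Age + Freq, type = "classification",
                   maxIter = 5, handleInvalid = "keep")
summary(model)
```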
@@ -365,7 +422,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, - maxMemoryInMB = 256, cacheNodeIds = FALSE) { + maxMemoryInMB = 256, cacheNodeIds = FALSE, + handleInvalid = c("error", "keep", "skip")) { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -386,6 +444,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo new("RandomForestRegressionModel", jobj = jobj) }, classification = { + handleInvalid <- match.arg(handleInvalid) if (is.null(impurity)) impurity <- "gini" impurity <- match.arg(impurity, c("gini", "entropy")) jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", @@ -395,7 +454,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo as.numeric(minInfoGain), as.integer(checkpointInterval), as.character(featureSubsetStrategy), seed, as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + as.integer(maxMemoryInMB), as.logical(cacheNodeIds), + handleInvalid) new("RandomForestClassificationModel", jobj = jobj) } ) @@ -499,3 +559,207 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path function(object, path, overwrite = FALSE) { write_internal(object, path, overwrite) }) + +#' Decision Tree Model for Regression and Classification +#' +#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{ +#' Decision Tree Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{ +#' Decision Tree Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param impurity Criterion used for information gain calculation. +#' For regression, must be "variance". For classification, must be one of +#' "entropy" and "gini", default is "gini". +#' @param seed integer seed for random number generation. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. 
If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type in classification model. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". +#' @param ... additional arguments passed to the method. +#' @aliases spark.decisionTree,SparkDataFrame,formula-method +#' @return \code{spark.decisionTree} returns a fitted Decision Tree model. +#' @rdname spark.decisionTree +#' @name spark.decisionTree +#' @export +#' @examples +#' \dontrun{ +#' # fit a Decision Tree Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Decision Tree Classification Model +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification") +#' } +#' @note spark.decisionTree since 2.3.0 +setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL, + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + maxMemoryInMB = 256, cacheNodeIds = FALSE, + handleInvalid = c("error", "keep", "skip")) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(impurity)) impurity <- "variance" + impurity <- match.arg(impurity, "variance") + jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), impurity, + as.integer(minInstancesPerNode), as.numeric(minInfoGain), + as.integer(checkpointInterval), seed, + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("DecisionTreeRegressionModel", jobj = jobj) + }, + classification = { + handleInvalid <- match.arg(handleInvalid) + if (is.null(impurity)) impurity <- "gini" + impurity <- match.arg(impurity, c("gini", "entropy")) + jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), impurity, + as.integer(minInstancesPerNode), as.numeric(minInfoGain), + as.integer(checkpointInterval), seed, + as.integer(maxMemoryInMB), as.logical(cacheNodeIds), + handleInvalid) + new("DecisionTreeClassificationModel", jobj = jobj) + } + ) + }) + +# Get the summary of a Decision Tree Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees). 
+#' @rdname spark.decisionTree +#' @aliases summary,DecisionTreeRegressionModel-method +#' @export +#' @note summary(DecisionTreeRegressionModel) since 2.3.0 +setMethod("summary", signature(object = "DecisionTreeRegressionModel"), + function(object) { + ans <- summary.decisionTree(object) + class(ans) <- "summary.DecisionTreeRegressionModel" + ans + }) + +# Prints the summary of a Decision Tree Regression Model + +#' @param x summary object of Decision Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.decisionTree +#' @export +#' @note print.summary.DecisionTreeRegressionModel since 2.3.0 +print.summary.DecisionTreeRegressionModel <- function(x, ...) { + print.summary.decisionTree(x) +} + +# Get the summary of a Decision Tree Classification Model + +#' @rdname spark.decisionTree +#' @aliases summary,DecisionTreeClassificationModel-method +#' @export +#' @note summary(DecisionTreeClassificationModel) since 2.3.0 +setMethod("summary", signature(object = "DecisionTreeClassificationModel"), + function(object) { + ans <- summary.decisionTree(object) + class(ans) <- "summary.DecisionTreeClassificationModel" + ans + }) + +# Prints the summary of a Decision Tree Classification Model + +#' @rdname spark.decisionTree +#' @export +#' @note print.summary.DecisionTreeClassificationModel since 2.3.0 +print.summary.DecisionTreeClassificationModel <- function(x, ...) { + print.summary.decisionTree(x) +} + +# Makes predictions from a Decision Tree Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named +#' "prediction". +#' @rdname spark.decisionTree +#' @aliases predict,DecisionTreeRegressionModel-method +#' @export +#' @note predict(DecisionTreeRegressionModel) since 2.3.0 +setMethod("predict", signature(object = "DecisionTreeRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.decisionTree +#' @aliases predict,DecisionTreeClassificationModel-method +#' @export +#' @note predict(DecisionTreeClassificationModel) since 2.3.0 +setMethod("predict", signature(object = "DecisionTreeClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Decision Tree Regression or Classification model to the input path. + +#' @param object A fitted Decision Tree regression model or classification model. +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE, +#' which means an exception is thrown if the output path exists.
+#' +#' @aliases write.ml,DecisionTreeRegressionModel,character-method +#' @rdname spark.decisionTree +#' @export +#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0 +setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,DecisionTreeClassificationModel,character-method +#' @rdname spark.decisionTree +#' @export +#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0 +setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 5dfef8625061..a53c92c2c481 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -32,8 +32,9 @@ #' @rdname write.ml #' @name write.ml #' @export -#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture}, -#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg}, +#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, +#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, +#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, #' @seealso \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, #' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, @@ -48,8 +49,9 @@ NULL #' @rdname predict #' @name predict #' @export -#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture}, -#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg}, +#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, +#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, +#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, #' @seealso \link{spark.kmeans}, #' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear} @@ -110,6 +112,10 @@ read.ml <- function(path) { new("RandomForestRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) { + new("DecisionTreeRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) { + new("DecisionTreeClassificationModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { new("GBTRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index cb5bdb90175b..d1ed6833d5d0 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -23,18 +23,24 @@ #' Create a structType object that contains the metadata for a SparkDataFrame. Intended for #' use with createDataFrame and toDF. #' -#' @param x a structField object (created with the field() function) +#' @param x a structField object (created with the \code{structField} method). Since Spark 2.3, +#' this can be a DDL-formatted string, which is a comma separated list of field +#' definitions, e.g., "a INT, b STRING". #' @param ... 
additional structField objects #' @return a structType object #' @rdname structType #' @export #' @examples #'\dontrun{ -#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' schema <- structType(structField("a", "integer"), structField("c", "string"), #' structField("avg", "double")) #' df1 <- gapply(df, list("a", "c"), #' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, #' schema) +#' schema <- structType("a INT, c STRING, avg DOUBLE") +#' df1 <- gapply(df, list("a", "c"), +#' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, +#' schema) #' } #' @note structType since 1.4.0 structType <- function(x, ...) { @@ -68,6 +74,23 @@ structType.structField <- function(x, ...) { structType(stObj) } +#' @rdname structType +#' @method structType character +#' @export +structType.character <- function(x, ...) { + if (!is.character(x)) { + stop("schema must be a DDL-formatted string.") + } + if (length(list(...)) > 0) { + stop("multiple DDL-formatted strings are not supported") + } + + stObj <- handledCallJStatic("org.apache.spark.sql.types.StructType", + "fromDDL", + x) + structType(stObj) +} + #' Print a Spark StructType. #' #' This function prints the contents of a StructType returned from the @@ -102,7 +125,7 @@ print.structType <- function(x, ...) { #' field1 <- structField("a", "integer") #' field2 <- structField("c", "string") #' field3 <- structField("avg", "double") -#' schema <- structType(field1, field2, field3) +#' schema <- structType(field1, field2, field3) #' df1 <- gapply(df, list("a", "c"), #' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, #' schema) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index d0a12b7ecec6..81507ea7186a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -113,7 +113,7 @@ sparkR.stop <- function() { #' list(spark.executor.memory="4g"), #' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), #' c("one.jar", "two.jar", "three.jar"), -#' c("com.databricks:spark-avro_2.10:2.0.1")) +#' c("com.databricks:spark-avro_2.11:2.0.1")) #'} #' @note sparkR.init since 1.4.0 sparkR.init <- function( @@ -357,7 +357,7 @@ sparkRHive.init <- function(jsc = NULL) { #' sparkR.session("yarn-client", "SparkR", "/home/spark", #' list(spark.executor.memory="4g"), #' c("one.jar", "two.jar", "three.jar"), -#' c("com.databricks:spark-avro_2.10:2.0.1")) +#' c("com.databricks:spark-avro_2.11:2.0.1")) #' sparkR.session(spark.master = "yarn-client", spark.executor.memory = "4g") #'} #' @note sparkR.session since 2.0.0 @@ -535,6 +535,23 @@ cancelJobGroup <- function(sc, groupId) { } } +#' Set a human readable description of the current job. +#' +#' Set a description that is shown as a job description in UI. +#' +#' @param value The job description of the current job. 
+#' @rdname setJobDescription +#' @name setJobDescription +#' @examples +#'\dontrun{ +#' setJobDescription("This is an example job.") +#'} +#' @note setJobDescription since 2.3.0 +setJobDescription <- function(value) { + sc <- getSparkContext() + invisible(callJMethod(sc, "setJobDescription", value)) +} + sparkConfToSubmitOps <- new.env() sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index d78a10893f92..9a9fa84044ce 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -52,22 +52,23 @@ setMethod("crosstab", collect(dataFrame(sct)) }) -#' Calculate the sample covariance of two numerical columns of a SparkDataFrame. +#' @details +#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two numerical +#' columns of \emph{one} SparkDataFrame. #' #' @param colName1 the name of the first column #' @param colName2 the name of the second column #' @return The covariance of the two columns. #' #' @rdname cov -#' @name cov #' @aliases cov,SparkDataFrame-method #' @family stat functions #' @export #' @examples -#'\dontrun{ -#' df <- read.json("/path/to/file.json") -#' cov <- cov(df, "title", "gender") -#' } +#' +#' \dontrun{ +#' cov(df, "mpg", "hp") +#' cov(df, df$mpg, df$hp)} #' @note cov since 1.6.0 setMethod("cov", signature(x = "SparkDataFrame"), @@ -93,11 +94,10 @@ setMethod("cov", #' @family stat functions #' @export #' @examples -#'\dontrun{ -#' df <- read.json("/path/to/file.json") -#' corr <- corr(df, "title", "gender") -#' corr <- corr(df, "title", "gender", method = "pearson") -#' } +#' +#' \dontrun{ +#' corr(df, "mpg", "hp") +#' corr(df, "mpg", "hp", method = "pearson")} #' @note corr since 1.6.0 setMethod("corr", signature(x = "SparkDataFrame"), diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index d29af00affb9..91483a4d23d9 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -907,3 +907,15 @@ basenameSansExtFromUrl <- function(url) { isAtomicLengthOne <- function(x) { is.atomic(x) && length(x) == 1 } + +is_windows <- function() { + .Platform$OS.type == "windows" +} + +hadoop_home_set <- function() { + !identical(Sys.getenv("HADOOP_HOME"), "") +} + +windows_with_hadoop <- function() { + !is_windows() || hadoop_home_set() +} diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R new file mode 100644 index 000000000000..de47162d5325 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
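
The `setJobDescription` wrapper above simply forwards the supplied string to the JVM `SparkContext`, so the description is shown for subsequently submitted jobs in the web UI. A minimal sketch of how it can be combined with the existing job-group helpers; the group id, description text, and the counted dataset are arbitrary examples, not part of this patch:

```r
library(SparkR)
sparkR.session(master = "local[1]", enableHiveSupport = FALSE)

# Group and describe the jobs triggered below.
setJobGroup("exampleGroup", "a group of example jobs", TRUE)
setJobDescription("count rows of the faithful dataset")

df <- createDataFrame(faithful)
count(df)   # this job carries the description above in the UI

clearJobGroup()
sparkR.session.stop()
```
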
+# + +context("basic tests for CRAN") + +test_that("create DataFrame from list or data.frame", { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + + i <- 4 + df <- createDataFrame(data.frame(dummy = 1:i)) + expect_equal(count(df), i) + + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) + + mtcarsdf <- createDataFrame(mtcars) + expect_equivalent(collect(mtcarsdf), mtcars) + + bytes <- as.raw(c(1, 2, 3)) + df <- createDataFrame(list(list(bytes))) + expect_equal(collect(df)[[1]][[1]], bytes) + + sparkR.session.stop() +}) + +test_that("spark.glm and predict", { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + + training <- suppressWarnings(createDataFrame(iris)) + # gaussian family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # Gamma family + x <- runif(100, -1, 1) + y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10) + df <- as.DataFrame(as.data.frame(list(x = x, y = y))) + model <- glm(y ~ x, family = Gamma, df) + out <- capture.output(print(summary(model))) + expect_true(any(grepl("Dispersion parameter for gamma family", out))) + + # tweedie family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + + sparkR.session.stop() +}) diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R deleted file mode 100644 index e0802a9b02d1..000000000000 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ /dev/null @@ -1,212 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -library(testthat) - -context("MLlib tree-based algorithms") - -# Tests for MLlib tree-based algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) - -absoluteSparkPath <- function(x) { - sparkHome <- sparkR.conf("spark.home") - file.path(sparkHome, x) -} - -test_that("spark.gbt", { - # regression - data <- suppressWarnings(createDataFrame(longley)) - model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) - predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, - 63.221, 63.639, 64.989, 63.761, - 66.019, 67.857, 68.169, 66.513, - 68.655, 69.564, 69.331, 70.551), - tolerance = 1e-4) - stats <- summary(model) - expect_equal(stats$numTrees, 20) - expect_equal(stats$maxDepth, 5) - expect_equal(stats$formula, "Employed ~ .") - expect_equal(stats$numFeatures, 6) - expect_equal(length(stats$treeWeights), 20) - - modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$maxDepth, stats2$maxDepth) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$treeWeights, stats2$treeWeights) - - unlink(modelPath) - - # classification - # label must be binary - GBTClassifier currently only supports binary classification. 
- iris2 <- iris[iris$Species != "virginica", ] - data <- suppressWarnings(createDataFrame(iris2)) - model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification") - stats <- summary(model) - expect_equal(stats$numFeatures, 2) - expect_equal(stats$numTrees, 20) - expect_equal(stats$maxDepth, 5) - expect_error(capture.output(stats), NA) - expect_true(length(capture.output(stats)) > 6) - predictions <- collect(predict(model, data))$prediction - # test string prediction values - expect_equal(length(grep("setosa", predictions)), 50) - expect_equal(length(grep("versicolor", predictions)), 50) - - modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) - - iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) - df <- suppressWarnings(createDataFrame(iris2)) - m <- spark.gbt(df, NumericSpecies ~ ., type = "classification") - s <- summary(m) - # test numeric prediction values - expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) - expect_equal(s$numFeatures, 5) - expect_equal(s$numTrees, 20) - expect_equal(stats$maxDepth, 5) - - # spark.gbt classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), - source = "libsvm") - model <- spark.gbt(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 692) -}) - -test_that("spark.randomForest", { - # regression - data <- suppressWarnings(createDataFrame(longley)) - model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, - numTrees = 1) - - predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, - 63.221, 63.639, 64.989, 63.761, - 66.019, 67.857, 68.169, 66.513, - 68.655, 69.564, 69.331, 70.551), - tolerance = 1e-4) - - stats <- summary(model) - expect_equal(stats$numTrees, 1) - expect_equal(stats$maxDepth, 5) - expect_error(capture.output(stats), NA) - expect_true(length(capture.output(stats)) > 6) - - model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, - numTrees = 20, seed = 123) - predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070, - 63.53160, 64.05470, 65.12710, 64.30450, - 66.70910, 67.86125, 68.08700, 67.21865, - 68.89275, 69.53180, 69.39640, 69.68250), - tolerance = 1e-4) - stats <- summary(model) - expect_equal(stats$numTrees, 20) - expect_equal(stats$maxDepth, 5) - - modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$maxDepth, stats2$maxDepth) - expect_equal(stats$treeWeights, stats2$treeWeights) - - unlink(modelPath) - - # classification - data <- 
suppressWarnings(createDataFrame(iris)) - model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", - maxDepth = 5, maxBins = 16) - - stats <- summary(model) - expect_equal(stats$numFeatures, 2) - expect_equal(stats$numTrees, 20) - expect_equal(stats$maxDepth, 5) - expect_error(capture.output(stats), NA) - expect_true(length(capture.output(stats)) > 6) - # Test string prediction values - predictions <- collect(predict(model, data))$prediction - expect_equal(length(grep("setosa", predictions)), 50) - expect_equal(length(grep("versicolor", predictions)), 50) - - modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) - - # Test numeric response variable - labelToIndex <- function(species) { - switch(as.character(species), - setosa = 0.0, - versicolor = 1.0, - virginica = 2.0 - ) - } - iris$NumericSpecies <- lapply(iris$Species, labelToIndex) - data <- suppressWarnings(createDataFrame(iris[-5])) - model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", - maxDepth = 5, maxBins = 16) - stats <- summary(model) - expect_equal(stats$numFeatures, 2) - expect_equal(stats$numTrees, 20) - expect_equal(stats$maxDepth, 5) - - # Test numeric prediction values - predictions <- collect(predict(model, data))$prediction - expect_equal(length(grep("1.0", predictions)), 50) - expect_equal(length(grep("2.0", predictions)), 50) - - # spark.randomForest classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), - source = "libsvm") - model <- spark.randomForest(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 4) -}) - -sparkR.session.stop() diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3a318b71ea06..2e31dc5f728c 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -30,8 +30,50 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) +# Waits indefinitely for a socket connecion by default. +selectTimeout <- NULL + while (TRUE) { - ready <- socketSelect(list(inputCon)) + ready <- socketSelect(list(inputCon), timeout = selectTimeout) + + # Note that the children should be terminated in the parent. If each child terminates + # itself, it appears that the resource is not released properly, that causes an unexpected + # termination of this daemon due to, for example, running out of file descriptors + # (see SPARK-21093). Therefore, the current implementation tries to retrieve children + # that are exited (but not terminated) and then sends a kill signal to terminate them properly + # in the parent. + # + # There are two paths that it attempts to send a signal to terminate the children in the parent. + # + # 1. Every second if any socket connection is not available and if there are child workers + # running. + # 2. Right after a socket connection is available. 
+ # + # In other words, the parent attempts to send the signal to the children every second if + # any worker is running or right before launching other worker children from the following + # new socket connection. + + # The process IDs of exited children are returned below. + children <- parallel:::selectChildren(timeout = 0) + + if (is.integer(children)) { + lapply(children, function(child) { + # This should be the PIDs of exited children. Otherwise, this returns raw bytes if any data + # was sent from this child. In this case, we discard it. + pid <- parallel:::readChild(child) + if (is.integer(pid)) { + # This checks if the data from this child is the same pid of this selected child. + if (child == pid) { + # If so, we terminate this child. + tools::pskill(child, tools::SIGUSR1) + } + } + }) + } else if (is.null(children)) { + # If it is NULL, there are no children. Waits indefinitely for a socket connecion. + selectTimeout <- NULL + } + if (ready) { port <- SparkR:::readInt(inputCon) # There is a small chance that it could be interrupted by signal, retry one time @@ -44,12 +86,15 @@ while (TRUE) { } p <- parallel:::mcfork() if (inherits(p, "masterProcess")) { + # Reach here because this is a child process. close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) try(source(script)) - # Set SIGUSR1 so that child can exit - tools::pskill(Sys.getpid(), tools::SIGUSR1) + # Note that this mcexit does not fully terminate this child. parallel:::mcexit(0L) + } else { + # Forking succeeded and we need to check if they finished their jobs every second. + selectTimeout <- 1 } } } diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/tests/fulltests/jarTest.R similarity index 96% rename from R/pkg/inst/tests/testthat/jarTest.R rename to R/pkg/tests/fulltests/jarTest.R index c9615c8d4faf..e2241e03b55f 100644 --- a/R/pkg/inst/tests/testthat/jarTest.R +++ b/R/pkg/tests/fulltests/jarTest.R @@ -16,7 +16,7 @@ # library(SparkR) -sc <- sparkR.session() +sc <- sparkR.session(master = "local[1]") helloTest <- SparkR:::callJStatic("sparkrtest.DummyClass", "helloWorld", diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/tests/fulltests/packageInAJarTest.R similarity index 96% rename from R/pkg/inst/tests/testthat/packageInAJarTest.R rename to R/pkg/tests/fulltests/packageInAJarTest.R index 4bc935c79eb0..ac706261999f 100644 --- a/R/pkg/inst/tests/testthat/packageInAJarTest.R +++ b/R/pkg/tests/fulltests/packageInAJarTest.R @@ -17,7 +17,7 @@ library(SparkR) library(sparkPackageTest) -sparkR.session() +sparkR.session(master = "local[1]") run1 <- myfunc(5L) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_Serde.R rename to R/pkg/tests/fulltests/test_Serde.R index b5f6f1b54fa8..6bbd201bf1d8 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/tests/fulltests/test_Serde.R @@ -17,7 +17,7 @@ context("SerDe functionality") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("SerDe of primitive types", { x <- callJStatic("SparkRHandler", "echo", 1L) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/tests/fulltests/test_Windows.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_Windows.R rename to R/pkg/tests/fulltests/test_Windows.R index 1d777ddb286d..b2ec6c67311d 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/tests/fulltests/test_Windows.R @@ 
-17,7 +17,7 @@ context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { - if (.Platform$OS.type != "windows") { + if (!is_windows()) { skip("This test is only for Windows, skipped") } diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/tests/fulltests/test_binaryFile.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_binaryFile.R rename to R/pkg/tests/fulltests/test_binaryFile.R index b5c279e3156e..758b174b8787 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/tests/fulltests/test_binaryFile.R @@ -18,7 +18,7 @@ context("functions on binary files") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_binary_function.R rename to R/pkg/tests/fulltests/test_binary_function.R index 59cb2e620440..442bed509bb1 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/tests/fulltests/test_binary_function.R @@ -18,7 +18,7 @@ context("binary functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/tests/fulltests/test_broadcast.R similarity index 92% rename from R/pkg/inst/tests/testthat/test_broadcast.R rename to R/pkg/tests/fulltests/test_broadcast.R index 65f204d096f4..fc2c7c2deb82 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/tests/fulltests/test_broadcast.R @@ -18,7 +18,7 @@ context("broadcast variables") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data @@ -27,7 +27,7 @@ rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) - randomMatBr <- broadcast(sc, randomMat) + randomMatBr <- broadcastRDD(sc, randomMat) useBroadcast <- function(x) { sum(SparkR:::value(randomMatBr) * x) diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/tests/fulltests/test_client.R similarity index 95% rename from R/pkg/inst/tests/testthat/test_client.R rename to R/pkg/tests/fulltests/test_client.R index 0cf25fe1dbf3..de624b572cc2 100644 --- a/R/pkg/inst/tests/testthat/test_client.R +++ b/R/pkg/tests/fulltests/test_client.R @@ -37,7 +37,7 @@ test_that("multiple packages don't produce a warning", { test_that("sparkJars sparkPackages as character vectors", { args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", - c("com.databricks:spark-avro_2.10:2.0.1")) + c("com.databricks:spark-avro_2.11:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") - expect_match(args, "--packages com.databricks:spark-avro_2.10:2.0.1") + expect_match(args, "--packages com.databricks:spark-avro_2.11:2.0.1") }) diff --git a/R/pkg/inst/tests/testthat/test_context.R 
b/R/pkg/tests/fulltests/test_context.R similarity index 93% rename from R/pkg/inst/tests/testthat/test_context.R rename to R/pkg/tests/fulltests/test_context.R index c64fe6edcd49..77635c5a256b 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -56,7 +56,7 @@ test_that("Check masked functions", { test_that("repeatedly starting and stopping SparkR", { for (i in 1:4) { - sc <- suppressWarnings(sparkR.init()) + sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster)) rdd <- parallelize(sc, 1:20, 2L) expect_equal(countRDD(rdd), 20) suppressWarnings(sparkR.stop()) @@ -65,7 +65,7 @@ test_that("repeatedly starting and stopping SparkR", { test_that("repeatedly starting and stopping SparkSession", { for (i in 1:4) { - sparkR.session(enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) df <- createDataFrame(data.frame(dummy = 1:i)) expect_equal(count(df), i) sparkR.session.stop() @@ -73,12 +73,12 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { - sc <- sparkR.sparkContext() # sc should get id 0 + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 sparkR.session.stop() - sc <- sparkR.sparkContext() # sc should get id 0 again + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 again # GC rdd1 before creating rdd3 and rdd2 after rm(rdd1) @@ -96,10 +96,11 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { - sc <- sparkR.sparkContext() + sc <- sparkR.sparkContext(master = sparkRTestMaster) setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() + setJobDescription("job description") suppressWarnings(setJobGroup(sc, "groupId", "job description", TRUE)) suppressWarnings(cancelJobGroup(sc, "groupId")) @@ -108,7 +109,7 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") sparkR.session.stop() }) @@ -161,14 +162,14 @@ test_that("sparkJars sparkPackages as comma-separated strings", { }) test_that("spark.lapply should perform simple transforms", { - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) doubled <- spark.lapply(1:10, function(x) { 2 * x }) expect_equal(doubled, as.list(2 * 1:10)) sparkR.session.stop() }) test_that("add and get file to be downloaded with Spark job on every node", { - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) # Test add file. 
path <- tempfile(pattern = "hello", fileext = ".txt") filename <- basename(path) diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/tests/fulltests/test_includePackage.R similarity index 95% rename from R/pkg/inst/tests/testthat/test_includePackage.R rename to R/pkg/tests/fulltests/test_includePackage.R index 563ea298c2dd..f4ea0d1b5cb2 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/tests/fulltests/test_includePackage.R @@ -18,7 +18,7 @@ context("include R packages") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/tests/fulltests/test_jvm_api.R similarity index 93% rename from R/pkg/inst/tests/testthat/test_jvm_api.R rename to R/pkg/tests/fulltests/test_jvm_api.R index 7348c893d0af..8b3b4f73de17 100644 --- a/R/pkg/inst/tests/testthat/test_jvm_api.R +++ b/R/pkg/tests/fulltests/test_jvm_api.R @@ -17,7 +17,7 @@ context("JVM API") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("Create and call methods on object", { jarr <- sparkR.newJObject("java.util.ArrayList") diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R similarity index 65% rename from R/pkg/inst/tests/testthat/test_mllib_classification.R rename to R/pkg/tests/fulltests/test_mllib_classification.R index cbc708718286..a4d0397236d1 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib classification algorithms, except for tree-based algorithms") # Tests for MLlib classification algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -38,9 +38,8 @@ test_that("spark.svmLinear", { expect_true(class(summary$coefficients[, 1]) == "numeric") coefs <- summary$coefficients[, "Estimate"] - expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085) + expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085) expect_true(all(abs(coefs - expected_coefs) < 0.1)) - expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2) # Test prediction with string label prediction <- predict(model, training) @@ -50,15 +49,17 @@ test_that("spark.svmLinear", { expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) # Test model save and load - modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- 
summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + } # Test prediction with numeric label label <- c(0.0, 0.0, 0.0, 1.0, 1.0) @@ -69,6 +70,20 @@ test_that("spark.svmLinear", { prediction <- collect(select(predict(model, df), "prediction")) expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0")) + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, handleInvalid = "skip") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "list") + }) test_that("spark.logit", { @@ -128,15 +143,17 @@ test_that("spark.logit", { expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) # Test model save and load - modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + } # R code to reproduce the result. 
# nolint start @@ -220,6 +237,61 @@ test_that("spark.logit", { model2 <- spark.logit(df2, label ~ feature, weightCol = "weight") prediction2 <- collect(select(predict(model2, df2), "prediction")) expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) + + # Test binomial logistic regression againt two classes with upperBoundsOnCoefficients + # and upperBoundsOnIntercepts + u <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) + model <- spark.logit(training, Species ~ ., upperBoundsOnCoefficients = u, + upperBoundsOnIntercepts = 1.0) + summary <- summary(model) + coefsR <- c(-11.13331, 1.00000, 0.00000, 1.00000, 0.00000) + coefs <- summary$coefficients[, "Estimate"] + expect_true(all(abs(coefsR - coefs) < 0.1)) + # Test upperBoundsOnCoefficients should be matrix + expect_error(spark.logit(training, Species ~ ., upperBoundsOnCoefficients = as.array(c(1, 2)), + upperBoundsOnIntercepts = 1.0)) + + # Test binomial logistic regression againt two classes with lowerBoundsOnCoefficients + # and lowerBoundsOnIntercepts + l <- matrix(c(0.0, -1.0, 0.0, -1.0), nrow = 1, ncol = 4) + model <- spark.logit(training, Species ~ ., lowerBoundsOnCoefficients = l, + lowerBoundsOnIntercepts = 0.0) + summary <- summary(model) + coefsR <- c(0, 0, -1, 0, 1.902192) + coefs <- summary$coefficients[, "Estimate"] + expect_true(all(abs(coefsR - coefs) < 0.1)) + # Test lowerBoundsOnCoefficients should be matrix + expect_error(spark.logit(training, Species ~ ., lowerBoundsOnCoefficients = as.array(c(1, 2)), + lowerBoundsOnIntercepts = 0.0)) + + # Test multinomial logistic regression with lowerBoundsOnCoefficients + # and lowerBoundsOnIntercepts + l <- matrix(c(0.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0), nrow = 2, ncol = 4) + model <- spark.logit(training, Species ~ ., family = "multinomial", + lowerBoundsOnCoefficients = l, + lowerBoundsOnIntercepts = as.array(c(0.0, 0.0))) + summary <- summary(model) + versicolorCoefsR <- c(42.639465, 7.258104, 14.330814, 16.298243, 11.716429) + virginicaCoefsR <- c(0.0002970796, 4.79274, 7.65047, 25.72793, 30.0021) + versicolorCoefs <- summary$coefficients[, "versicolor"] + virginicaCoefs <- summary$coefficients[, "virginica"] + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.logit(traindf, clicked ~ ., regParam = 0.5) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.logit(traindf, clicked ~ ., regParam = 0.5, handleInvalid = "keep") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") + }) test_that("spark.mlp", { @@ -243,19 +315,21 @@ test_that("spark.mlp", { expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load - modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - - expect_equal(summary2$numOfInputs, 4) - expect_equal(summary2$numOfOutputs, 3) - 
expect_equal(summary2$layers, c(4, 5, 4, 3)) - expect_equal(length(summary2$weights), 64) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + + expect_equal(summary2$numOfInputs, 4) + expect_equal(summary2$numOfOutputs, 3) + expect_equal(summary2$layers, c(4, 5, 4, 3)) + expect_equal(length(summary2$weights), 64) + + unlink(modelPath) + } # Test default parameter model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) @@ -299,6 +373,21 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3)) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "list") + }) test_that("spark.naiveBayes", { @@ -354,16 +443,18 @@ test_that("spark.naiveBayes", { "Yes", "Yes", "No", "No")) # Test model save/load - modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - expect_equal(s$apriori, s2$apriori) - expect_equal(s$tables, s2$tables) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 <- summary(m2) + expect_equal(s$apriori, s2$apriori) + expect_equal(s$tables, s2$tables) + + unlink(modelPath) + } # Test e1071::naiveBayes if (requireNamespace("e1071", quietly = TRUE)) { @@ -380,6 +471,20 @@ test_that("spark.naiveBayes", { expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6) expect_equal(sum(s$apriori), 1) expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6) + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0, handleInvalid = "keep") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/tests/fulltests/test_mllib_clustering.R similarity index 79% rename from 
R/pkg/inst/tests/testthat/test_mllib_clustering.R rename to R/pkg/tests/fulltests/test_mllib_clustering.R index 1661e987b730..4110e13da494 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/tests/fulltests/test_mllib_clustering.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib clustering algorithms") # Tests for MLlib clustering algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -53,18 +53,20 @@ test_that("spark.bisectingKmeans", { c(0, 1, 2, 3)) # Test model save/load - modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + } }) test_that("spark.gaussianMixture", { @@ -125,18 +127,20 @@ test_that("spark.gaussianMixture", { expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) # Test model save/load - modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$lambda, stats2$lambda) - expect_equal(unlist(stats$mu), unlist(stats2$mu)) - expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) - expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$lambda, stats2$lambda) + expect_equal(unlist(stats$mu), unlist(stats2$mu)) + expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) + expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) + + unlink(modelPath) + } }) test_that("spark.kmeans", { @@ -171,18 +175,20 @@ test_that("spark.kmeans", { expect_true(class(summary.model$coefficients[1, ]) == "numeric") # Test model save/load - modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - 
expect_true(summary2$is.loaded) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + } # Test Kmeans on dataset that is sensitive to seed value col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) @@ -236,22 +242,24 @@ test_that("spark.lda with libsvm", { expect_true(logPrior <= 0 & !is.na(logPrior)) # Test model save/load - modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - - expect_true(stats2$isDistributed) - expect_equal(logLikelihood, stats2$logLikelihood) - expect_equal(logPerplexity, stats2$logPerplexity) - expect_equal(vocabSize, stats2$vocabSize) - expect_equal(vocabulary, stats2$vocabulary) - expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood) - expect_equal(logPrior, stats2$logPrior) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_true(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_equal(vocabulary, stats2$vocabulary) + expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood) + expect_equal(logPrior, stats2$logPrior) + + unlink(modelPath) + } }) test_that("spark.lda with text input", { diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R similarity index 85% rename from R/pkg/inst/tests/testthat/test_mllib_fpm.R rename to R/pkg/tests/fulltests/test_mllib_fpm.R index c38f1133897d..69dda52f0c27 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_fpm.R +++ b/R/pkg/tests/fulltests/test_mllib_fpm.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib frequent pattern mining") # Tests for MLlib frequent pattern mining algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.fpGrowth", { data <- selectExpr(createDataFrame(data.frame(items = c( @@ -62,15 +62,17 @@ test_that("spark.fpGrowth", { expect_equivalent(expected_predictions, collect(predict(model, new_data))) - modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") - write.ml(model, modelPath, overwrite = TRUE) - loaded_model <- read.ml(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") + write.ml(model, modelPath, overwrite = TRUE) + loaded_model <- read.ml(modelPath) - expect_equivalent( - itemsets, - collect(spark.freqItemsets(loaded_model))) + expect_equivalent( + itemsets, + collect(spark.freqItemsets(loaded_model))) - unlink(modelPath) + unlink(modelPath) + } model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, 
minConfidence = 0.8) expect_equal( diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/tests/fulltests/test_mllib_recommendation.R similarity index 59% rename from R/pkg/inst/tests/testthat/test_mllib_recommendation.R rename to R/pkg/tests/fulltests/test_mllib_recommendation.R index 6b1040db9305..4d919c9d746b 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R +++ b/R/pkg/tests/fulltests/test_mllib_recommendation.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib recommendation algorithms") # Tests for MLlib recommendation algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.als", { data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -37,29 +37,31 @@ test_that("spark.als", { tolerance = 1e-4) # Test model save/load - modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats2$rating, "score") - userFactors <- collect(stats$userFactors) - itemFactors <- collect(stats$itemFactors) - userFactors2 <- collect(stats2$userFactors) - itemFactors2 <- collect(stats2$itemFactors) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats2$rating, "score") + userFactors <- collect(stats$userFactors) + itemFactors <- collect(stats$itemFactors) + userFactors2 <- collect(stats2$userFactors) + itemFactors2 <- collect(stats2$itemFactors) - orderUser <- order(userFactors$id) - orderUser2 <- order(userFactors2$id) - expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) - expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) + orderUser <- order(userFactors$id) + orderUser2 <- order(userFactors2$id) + expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) + expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) - orderItem <- order(itemFactors$id) - orderItem2 <- order(itemFactors2$id) - expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) - expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) + orderItem <- order(itemFactors$id) + orderItem2 <- order(itemFactors2$id) + expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) + expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) - unlink(modelPath) + unlink(modelPath) + } }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/tests/fulltests/test_mllib_regression.R similarity index 81% rename from R/pkg/inst/tests/testthat/test_mllib_regression.R rename to R/pkg/tests/fulltests/test_mllib_regression.R index 3e9ad7719807..23daca75fcc2 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/tests/fulltests/test_mllib_regression.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib regression algorithms, except for tree-based algorithms") # Tests for MLlib regression algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, 
enableHiveSupport = FALSE) test_that("formula of spark.glm", { training <- suppressWarnings(createDataFrame(iris)) @@ -173,6 +173,14 @@ test_that("spark.glm summary", { expect_equal(stats$df.residual, rStats$df.residual) expect_equal(stats$aic, rStats$aic) + # Test spark.glm works with offset + training <- suppressWarnings(createDataFrame(iris)) + stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species, + family = poisson(), offsetCol = "Petal_Length")) + rStats <- suppressWarnings(summary(glm(Sepal.Width ~ Sepal.Length + Species, + data = iris, family = poisson(), offset = iris$Petal.Length))) + expect_true(all(abs(rStats$coefficients - stats$coefficients) < 1e-3)) + # Test summary works on base GLM models baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) baseSummary <- summary(baseModel) @@ -367,6 +375,49 @@ test_that("glm save/load", { unlink(modelPath) }) +test_that("spark.glm and glm with string encoding", { + t <- as.data.frame(Titanic, stringsAsFactors = FALSE) + df <- createDataFrame(t) + + # base R + rm <- stats::glm(Freq ~ Sex + Age, family = "gaussian", data = t) + # spark.glm with default stringIndexerOrderType = "frequencyDesc" + sm0 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian") + # spark.glm with stringIndexerOrderType = "alphabetDesc" + sm1 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian", + stringIndexerOrderType = "alphabetDesc") + # glm with stringIndexerOrderType = "alphabetDesc" + sm2 <- glm(Freq ~ Sex + Age, family = "gaussian", data = df, + stringIndexerOrderType = "alphabetDesc") + + rStats <- summary(rm) + rCoefs <- rStats$coefficients + sStats <- lapply(list(sm0, sm1, sm2), summary) + # order by coefficient size since column rendering may be different + o <- order(rCoefs[, 1]) + + # default encoding does not produce same results as R + expect_false(all(abs(rCoefs[o, ] - sStats[[1]]$coefficients[o, ]) < 1e-4)) + + # all estimates should be the same as R with stringIndexerOrderType = "alphabetDesc" + test <- lapply(sStats[2:3], function(stats) { + expect_true(all(abs(rCoefs[o, ] - stats$coefficients[o, ]) < 1e-4)) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + }) + + # fitted values should be equal regardless of string encoding + rVals <- predict(rm, t) + test <- lapply(list(sm0, sm1, sm2), function(sm) { + vals <- collect(select(predict(sm, df), "prediction")) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + }) +}) + test_that("spark.isoreg", { label <- c(7.0, 5.0, 3.0, 5.0, 1.0) feature <- c(0.0, 1.0, 2.0, 3.0, 4.0) @@ -389,14 +440,16 @@ test_that("spark.isoreg", { expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) # Test model save/load - modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - expect_equal(result, summary(model2)) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + expect_equal(result, summary(model2)) + + unlink(modelPath) + } 
}) test_that("spark.survreg", { @@ -438,17 +491,19 @@ test_that("spark.survreg", { 2.390146, 2.891269, 2.891269), tolerance = 1e-4) # Test model save/load - modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - coefs2 <- as.vector(stats2$coefficients[, 1]) - expect_equal(coefs, coefs2) - expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + coefs2 <- as.vector(stats2$coefficients[, 1]) + expect_equal(coefs, coefs2) + expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) + + unlink(modelPath) + } # Test survival::survreg if (requireNamespace("survival", quietly = TRUE)) { @@ -458,6 +513,25 @@ test_that("spark.survreg", { model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), NA) expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) + + # Test stringIndexerOrderType + rData <- as.data.frame(rData) + rData$sex2 <- c("female", "male")[rData$sex + 1] + df <- createDataFrame(rData) + expect_error( + rModel <- survival::survreg(survival::Surv(time, status) ~ x + sex2, rData), NA) + rCoefs <- as.numeric(summary(rModel)$table[, 1]) + model <- spark.survreg(df, Surv(time, status) ~ x + sex2) + coefs <- as.vector(summary(model)$coefficients[, 1]) + o <- order(rCoefs) + # stringIndexerOrderType = "frequencyDesc" produces different estimates from R + expect_false(all(abs(rCoefs[o] - coefs[o]) < 1e-4)) + + # stringIndexerOrderType = "alphabetDesc" produces the same estimates as R + model <- spark.survreg(df, Surv(time, status) ~ x + sex2, + stringIndexerOrderType = "alphabetDesc") + coefs <- as.vector(summary(model)$coefficients[, 1]) + expect_true(all(abs(rCoefs[o] - coefs[o]) < 1e-4)) } }) diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/tests/fulltests/test_mllib_stat.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_mllib_stat.R rename to R/pkg/tests/fulltests/test_mllib_stat.R index beb148e7702f..1600833a5d03 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_stat.R +++ b/R/pkg/tests/fulltests/test_mllib_stat.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib statistics algorithms") # Tests for MLlib statistics algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.kstest", { data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R new file mode 100644 index 000000000000..facd3a941cf1 --- /dev/null +++ b/R/pkg/tests/fulltests/test_mllib_tree.R @@ -0,0 +1,365 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib tree-based algorithms") + +# Tests for MLlib tree-based algorithms in SparkR +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + +absoluteSparkPath <- function(x) { + sparkHome <- sparkR.conf("spark.home") + file.path(sparkHome, x) +} + +test_that("spark.gbt", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + expect_equal(stats$formula, "Employed ~ .") + expect_equal(stats$numFeatures, 6) + expect_equal(length(stats$treeWeights), 20) + + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + } + + # classification + # label must be binary - GBTClassifier currently only supports binary classification. 
+ iris2 <- iris[iris$Species != "virginica", ] + data <- suppressWarnings(createDataFrame(iris2)) + model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification", seed = 12) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + predictions <- collect(predict(model, data))$prediction + # test string prediction values + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + } + + iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) + df <- suppressWarnings(createDataFrame(iris2)) + m <- spark.gbt(df, NumericSpecies ~ ., type = "classification", seed = 12) + s <- summary(m) + # test numeric prediction values + expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) + expect_equal(s$numFeatures, 5) + expect_equal(s$numTrees, 20) + expect_equal(stats$maxDepth, 5) + + # spark.gbt classification can work on libsvm data + if (windows_with_hadoop()) { + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification", seed = 12) + expect_equal(summary(model)$numFeatures, 692) + } + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.gbt(traindf, clicked ~ ., type = "classification", seed = 23) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep", + seed = 23) + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") +}) + +test_that("spark.randomForest", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 1, seed = 1) + + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$numTrees, 1) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 20, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070, + 63.53160, 64.05470, 65.12710, 64.30450, + 66.70910, 67.86125, 68.08700, 
67.21865, + 68.89275, 69.53180, 69.39640, 69.68250), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + } + + # classification + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16, seed = 123) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + } + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16, seed = 123) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.randomForest(traindf, clicked ~ ., type = "classification", + maxDepth = 10, maxBins = 10, numTrees = 10, seed = 123) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.randomForest(traindf, clicked ~ ., type = "classification", + maxDepth = 10, maxBins = 10, numTrees = 10, + handleInvalid = "keep", seed = 123) + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), 
"character") + + # spark.randomForest classification can work on libsvm data + if (windows_with_hadoop()) { + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification", seed = 123) + expect_equal(summary(model)$numFeatures, 4) + } +}) + +test_that("spark.decisionTree", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + seed = 42) + + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) + + unlink(modelPath) + } + + # classification + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16, seed = 43) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + } + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.decisionTree(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16, seed = 44) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$maxDepth, 5) + + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) + + # spark.decisionTree classification can work on libsvm data + if (windows_with_hadoop()) { + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + 
model <- spark.decisionTree(data, label ~ features, "classification", seed = 45) + expect_equal(summary(model)$numFeatures, 4) + } + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.decisionTree(traindf, clicked ~ ., type = "classification", + maxDepth = 5, maxBins = 16, seed = 46) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.decisionTree(traindf, clicked ~ ., type = "classification", + maxDepth = 5, maxBins = 16, handleInvalid = "keep", seed = 46) + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/tests/fulltests/test_parallelize_collect.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_parallelize_collect.R rename to R/pkg/tests/fulltests/test_parallelize_collect.R index 55972e1ba469..3d122ccaf448 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/tests/fulltests/test_parallelize_collect.R @@ -33,7 +33,7 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) strPairs <- list(list(strList, strList), list(strList, strList)) # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Tests diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R similarity index 99% rename from R/pkg/inst/tests/testthat/test_rdd.R rename to R/pkg/tests/fulltests/test_rdd.R index b72c801dd958..6ee1fceffd82 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/tests/fulltests/test_rdd.R @@ -18,7 +18,7 @@ context("basic RDD functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data @@ -40,8 +40,8 @@ test_that("first on RDD", { }) test_that("count and length on RDD", { - expect_equal(countRDD(rdd), 10) - expect_equal(lengthRDD(rdd), 10) + expect_equal(countRDD(rdd), 10) + expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/tests/fulltests/test_shuffle.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_shuffle.R rename to R/pkg/tests/fulltests/test_shuffle.R index d38efab0fd1d..98300c67c415 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/tests/fulltests/test_shuffle.R @@ -18,7 +18,7 @@ context("partitionBy, groupByKey, reduceByKey etc.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/tests/fulltests/test_sparkR.R similarity index 100% rename from 
R/pkg/inst/tests/testthat/test_sparkR.R rename to R/pkg/tests/fulltests/test_sparkR.R diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R similarity index 87% rename from R/pkg/inst/tests/testthat/test_sparkSQL.R rename to R/pkg/tests/fulltests/test_sparkSQL.R index 08296354ca7e..4e62be9b4d61 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -61,7 +61,11 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR filesBefore <- list.files(path = sparkRDir, all.files = TRUE) -sparkSession <- sparkR.session() +sparkSession <- if (windows_with_hadoop()) { + sparkR.session(master = sparkRTestMaster) + } else { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + } sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", @@ -96,6 +100,10 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}} mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) +if (is_windows()) { + Sys.setenv(TZ = "GMT") +} + test_that("calling sparkRSQL.init returns existing SQL context", { sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) @@ -138,6 +146,13 @@ test_that("structType and structField", { expect_is(testSchema, "structType") expect_is(testSchema$fields()[[2]], "structField") expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") + + testSchema <- structType("a STRING, b INT") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") + + expect_error(structType("A stri"), "DataType stri is not supported.") }) test_that("structField type strings", { @@ -312,51 +327,53 @@ test_that("createDataFrame uses files for large objects", { }) test_that("read/write csv as DataFrame", { - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") - mockLinesCsv <- c("year,make,model,comment,blank", - "\"2012\",\"Tesla\",\"S\",\"No comment\",", - "1997,Ford,E350,\"Go get one now they are going fast\",", - "2015,Chevy,Volt", - "NA,Dummy,Placeholder") - writeLines(mockLinesCsv, csvPath) - - # default "header" is false, inferSchema to handle "year" as "int" - df <- read.df(csvPath, "csv", header = "true", inferSchema = "true") - expect_equal(count(df), 4) - expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) - expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), - sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) - - # since "year" is "int", let's skip the NA values - withoutna <- na.omit(df, how = "any", cols = "year") - expect_equal(count(withoutna), 3) - - unlink(csvPath) - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") - mockLinesCsv <- c("year,make,model,comment,blank", - "\"2012\",\"Tesla\",\"S\",\"No comment\",", - "1997,Ford,E350,\"Go get one now they are going fast\",", - "2015,Chevy,Volt", - "Empty,Dummy,Placeholder") - writeLines(mockLinesCsv, csvPath) - - df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") - expect_equal(count(df2), 4) - withoutna2 <- na.omit(df2, how = "any", cols = "year") - expect_equal(count(withoutna2), 3) - expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) - - # writing csv file - 
csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") - write.df(df2, path = csvPath2, "csv", header = "true") - df3 <- read.df(csvPath2, "csv", header = "true") - expect_equal(nrow(df3), nrow(df2)) - expect_equal(colnames(df3), colnames(df2)) - csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) - expect_equal(colnames(df3), colnames(csv)) - - unlink(csvPath) - unlink(csvPath2) + if (windows_with_hadoop()) { + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "NA,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + # default "header" is false, inferSchema to handle "year" as "int" + df <- read.df(csvPath, "csv", header = "true", inferSchema = "true") + expect_equal(count(df), 4) + expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) + expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), + sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) + + # since "year" is "int", let's skip the NA values + withoutna <- na.omit(df, how = "any", cols = "year") + expect_equal(count(withoutna), 3) + + unlink(csvPath) + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "Empty,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") + expect_equal(count(df2), 4) + withoutna2 <- na.omit(df2, how = "any", cols = "year") + expect_equal(count(withoutna2), 3) + expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) + + # writing csv file + csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") + write.df(df2, path = csvPath2, "csv", header = "true") + df3 <- read.df(csvPath2, "csv", header = "true") + expect_equal(nrow(df3), nrow(df2)) + expect_equal(colnames(df3), colnames(df2)) + csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) + expect_equal(colnames(df3), colnames(csv)) + + unlink(csvPath) + unlink(csvPath2) + } }) test_that("Support other types for options", { @@ -579,48 +596,50 @@ test_that("Collect DataFrame with complex types", { }) test_that("read/write json files", { - # Test read.df - df <- read.df(jsonPath, "json") - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - # Test read.df with a user defined schema - schema <- structType(structField("name", type = "string"), - structField("age", type = "double")) - - df1 <- read.df(jsonPath, "json", schema) - expect_is(df1, "SparkDataFrame") - expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) - - # Test loadDF - df2 <- loadDF(jsonPath, "json", schema) - expect_is(df2, "SparkDataFrame") - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) - - # Test read.json - df <- read.json(jsonPath) - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - # Test write.df - jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") - write.df(df, jsonPath2, "json", mode = "overwrite") - - # Test write.json - jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") - write.json(df, jsonPath3) - - # Test read.json()/jsonFile() works with multiple input paths - jsonDF1 <- 
read.json(c(jsonPath2, jsonPath3)) - expect_is(jsonDF1, "SparkDataFrame") - expect_equal(count(jsonDF1), 6) - # Suppress warnings because jsonFile is deprecated - jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) - expect_is(jsonDF2, "SparkDataFrame") - expect_equal(count(jsonDF2), 6) - - unlink(jsonPath2) - unlink(jsonPath3) + if (windows_with_hadoop()) { + # Test read.df + df <- read.df(jsonPath, "json") + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + + # Test read.df with a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(jsonPath, "json", schema) + expect_is(df1, "SparkDataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Test loadDF + df2 <- loadDF(jsonPath, "json", schema) + expect_is(df2, "SparkDataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + # Test read.json + df <- read.json(jsonPath) + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + + # Test write.df + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") + write.df(df, jsonPath2, "json", mode = "overwrite") + + # Test write.json + jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") + write.json(df, jsonPath3) + + # Test read.json()/jsonFile() works with multiple input paths + jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) + expect_is(jsonDF1, "SparkDataFrame") + expect_equal(count(jsonDF1), 6) + # Suppress warnings because jsonFile is deprecated + jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) + expect_is(jsonDF2, "SparkDataFrame") + expect_equal(count(jsonDF2), 6) + + unlink(jsonPath2) + unlink(jsonPath3) + } }) test_that("read/write json files - compression option", { @@ -651,24 +670,27 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test tableNames and tables", { + count <- count(listTables()) + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") - expect_equal(length(tableNames()), 1) - expect_equal(length(tableNames("default")), 1) + expect_equal(length(tableNames()), count + 1) + expect_equal(length(tableNames("default")), count + 1) + tables <- listTables() - expect_equal(count(tables), 1) + expect_equal(count(tables), count + 1) expect_equal(count(tables()), count(tables)) expect_true("tableName" %in% colnames(tables())) expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) suppressWarnings(registerTempTable(df, "table2")) tables <- listTables() - expect_equal(count(tables), 2) + expect_equal(count(tables), count + 2) suppressWarnings(dropTempTable("table1")) expect_true(dropTempView("table2")) tables <- listTables() - expect_equal(count(tables), 0) + expect_equal(count(tables), count + 0) }) test_that( @@ -705,33 +727,35 @@ test_that("test cache, uncache and clearCache", { }) test_that("insertInto() on a registered table", { - df <- read.df(jsonPath, "json") - write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- read.df(parquetPath, "parquet") - - lines <- c("{\"name\":\"Bob\", \"age\":24}", - "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - writeLines(lines, jsonPath2) - df2 <- read.df(jsonPath2, "json") - write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(parquetPath2, "parquet") - - createOrReplaceTempView(dfParquet, "table1") - 
insertInto(dfParquet2, "table1") - expect_equal(count(sql("select * from table1")), 5) - expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") - expect_true(dropTempView("table1")) - - createOrReplaceTempView(dfParquet, "table1") - insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_equal(count(sql("select * from table1")), 2) - expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") - expect_true(dropTempView("table1")) - - unlink(jsonPath2) - unlink(parquetPath2) + if (windows_with_hadoop()) { + df <- read.df(jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- read.df(jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- read.df(parquetPath2, "parquet") + + createOrReplaceTempView(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_equal(count(sql("select * from table1")), 5) + expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") + expect_true(dropTempView("table1")) + + createOrReplaceTempView(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_equal(count(sql("select * from table1")), 2) + expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") + expect_true(dropTempView("table1")) + + unlink(jsonPath2) + unlink(parquetPath2) + } }) test_that("tableToDF() returns a new DataFrame", { @@ -911,14 +935,16 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", }) test_that("setCheckpointDir(), checkpoint() on a DataFrame", { - checkpointDir <- file.path(tempdir(), "cproot") - expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) - - setCheckpointDir(checkpointDir) - df <- read.json(jsonPath) - df <- checkpoint(df) - expect_is(df, "SparkDataFrame") - expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + if (windows_with_hadoop()) { + checkpointDir <- file.path(tempdir(), "cproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + + setCheckpointDir(checkpointDir) + df <- read.json(jsonPath) + df <- checkpoint(df) + expect_is(df, "SparkDataFrame") + expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + } }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { @@ -1090,6 +1116,20 @@ test_that("sample on a DataFrame", { sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + # Different arguments + df <- createDataFrame(as.list(seq(10))) + expect_equal(count(sample(df, fraction = 0.5, seed = 3)), 4) + expect_equal(count(sample(df, withReplacement = TRUE, fraction = 0.5, seed = 3)), 2) + expect_equal(count(sample(df, fraction = 1.0)), 10) + expect_equal(count(sample(df, fraction = 1L)), 10) + expect_equal(count(sample(df, FALSE, fraction = 1.0)), 10) + + expect_error(sample(df, fraction = "a"), "fraction must be numeric") + expect_error(sample(df, "a", fraction = 0.1), "however, got character") + expect_error(sample(df, fraction = 1, seed = NA), "seed must not be NULL or NA; however, got NA") + expect_error(sample(df, fraction = -1.0), + "illegal argument - requirement failed: 
Sampling fraction \\(-1.0\\)") + # nolint start # Test base::sample is working #expect_equal(length(sample(1:12)), 12) @@ -1187,6 +1227,16 @@ test_that("select with column", { expect_equal(columns(df4), c("name", "age")) expect_equal(count(df4), 3) + # Test select with alias + df5 <- alias(df, "table") + + expect_equal(columns(select(df5, column("table.name"))), "name") + expect_equal(columns(select(df5, "table.name")), "name") + + # Test that stats::alias is not masked + expect_is(alias(aov(yield ~ block + N * P * K, npk)), "listof") + + expect_error(select(df, c("name", "age"), "name"), "To select multiple columns, use a character vector or list for col") }) @@ -1276,45 +1326,47 @@ test_that("column calculation", { }) test_that("test HiveContext", { - setHiveContext(sc) - - schema <- structType(structField("name", "string"), structField("age", "integer"), - structField("height", "float")) - createTable("people", source = "json", schema = schema) - df <- read.df(jsonPathNa, "json", schema) - insertInto(df, "people") - expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) - sql("DROP TABLE people") - - df <- createTable("json", jsonPath, "json") - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - df2 <- sql("select * from json") - expect_is(df2, "SparkDataFrame") - expect_equal(count(df2), 3) - - jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "json2", "json", "append", path = jsonPath2) - df3 <- sql("select * from json2") - expect_is(df3, "SparkDataFrame") - expect_equal(count(df3), 3) - unlink(jsonPath2) - - hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "hivetestbl", path = hivetestDataPath) - df4 <- sql("select * from hivetestbl") - expect_is(df4, "SparkDataFrame") - expect_equal(count(df4), 3) - unlink(hivetestDataPath) - - parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) - df5 <- sql("select * from parquetest") - expect_is(df5, "SparkDataFrame") - expect_equal(count(df5), 3) - unlink(parquetDataPath) - - unsetHiveContext() + if (windows_with_hadoop()) { + setHiveContext(sc) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + createTable("people", source = "json", schema = schema) + df <- read.df(jsonPathNa, "json", schema) + insertInto(df, "people") + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) + sql("DROP TABLE people") + + df <- createTable("json", jsonPath, "json") + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + df2 <- sql("select * from json") + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "json2", "json", "append", path = jsonPath2) + df3 <- sql("select * from json2") + expect_is(df3, "SparkDataFrame") + expect_equal(count(df3), 3) + unlink(jsonPath2) + + hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "hivetestbl", path = hivetestDataPath) + df4 <- sql("select * from hivetestbl") + expect_is(df4, "SparkDataFrame") + expect_equal(count(df4), 3) + unlink(hivetestDataPath) + + parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) + df5 <- sql("select * from parquetest") + 
expect_is(df5, "SparkDataFrame") + expect_equal(count(df5), 3) + unlink(parquetDataPath) + + unsetHiveContext() + } }) test_that("column operators", { @@ -1351,6 +1403,8 @@ test_that("column functions", { c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") c21 <- posexplode_outer(c) + explode_outer(c) c22 <- not(c) + c23 <- trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + + trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm") # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1366,6 +1420,11 @@ test_that("column functions", { expect_equal(collect(df2)[[3, 1]], FALSE) expect_equal(collect(df2)[[3, 2]], TRUE) + # Test that input_file_name() + actual_names <- sort(collect(distinct(select(df, input_file_name())))) + expect_equal(length(actual_names), 1) + expect_equal(basename(actual_names[1, 1]), basename(jsonPath)) + df3 <- select(df, between(df$name, c("Apache", "Spark"))) expect_equal(collect(df3)[[1, 1]], TRUE) expect_equal(collect(df3)[[2, 1]], FALSE) @@ -1391,6 +1450,14 @@ test_that("column functions", { result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + # Test map_keys() and map_values() + df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) + result <- collect(select(df, map_keys(df$map)))[[1]] + expect_equal(result, list(list("x", "y"))) + + result <- collect(select(df, map_values(df$map)))[[1]] + expect_equal(result, list(list(1, 2))) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) @@ -1438,17 +1505,27 @@ test_that("column functions", { j <- collect(select(df, alias(to_json(df$people), "json"))) expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") + df <- sql("SELECT map('name', 'Bob') as people") + j <- collect(select(df, alias(to_json(df$people), "json"))) + expect_equal(j[order(j$json), ][1], "{\"name\":\"Bob\"}") + + df <- sql("SELECT array(map('name', 'Bob'), map('name', 'Alice')) as people") + j <- collect(select(df, alias(to_json(df$people), "json"))) + expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") + df <- read.json(mapTypeJsonPath) j <- collect(select(df, alias(to_json(df$info), "json"))) expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") df <- as.DataFrame(j) - schema <- structType(structField("age", "integer"), - structField("height", "double")) - s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) - expect_equal(ncol(s), 1) - expect_equal(nrow(s), 3) - expect_is(s[[1]][[1]], "struct") - expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + schemas <- list(structType(structField("age", "integer"), structField("height", "double")), + "age INT, height DOUBLE") + for (schema in schemas) { + s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) + expect_equal(ncol(s), 1) + expect_equal(nrow(s), 3) + expect_is(s[[1]][[1]], "struct") + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + } # passing option df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) @@ -1466,14 +1543,15 @@ test_that("column functions", { # check if array type in string is correctly supported. 
jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" df <- as.DataFrame(list(list("people" = jsonArr))) - schema <- structType(structField("name", "string")) - arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) - expect_equal(ncol(arr), 1) - expect_equal(nrow(arr), 1) - expect_is(arr[[1]][[1]], "list") - expect_equal(length(arr$arrcol[[1]]), 2) - expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") - expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + for (schema in list(structType(structField("name", "string")), "name STRING")) { + arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) + expect_equal(ncol(arr), 1) + expect_equal(nrow(arr), 1) + expect_is(arr[[1]][[1]], "list") + expect_equal(length(arr$arrcol[[1]]), 2) + expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") + expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + } # Test create_array() and create_map() df <- as.DataFrame(data.frame( @@ -1497,7 +1575,6 @@ test_that("column functions", { collect(select(df, alias(not(df$is_true), "is_false"))), data.frame(is_false = c(FALSE, TRUE, NA)) ) - }) test_that("column binary mathfunctions", { @@ -1848,7 +1925,11 @@ test_that("test multi-dimensional aggregations with cube and rollup", { orderBy( agg( cube(df, "year", "department"), - expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + expr("sum(salary) AS total_salary"), + expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") ), "year", "department" ) @@ -1875,6 +1956,30 @@ test_that("test multi-dimensional aggregations with cube and rollup", { mean(c(21000, 32000, 22000)), # 2017 22000, 32000, 21000 # 2017 each department ), + grouping_year = c( + 1, # global + 1, 1, 1, # by department + 0, # 2016 + 0, 0, 0, # 2016 by department + 0, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_department = c( + 1, # global + 0, 0, 0, # by department + 1, # 2016 + 0, 0, 0, # 2016 by department + 1, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_id = c( + 3, # 11 + 2, 2, 2, # 10 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), stringsAsFactors = FALSE ) @@ -1896,7 +2001,10 @@ test_that("test multi-dimensional aggregations with cube and rollup", { orderBy( agg( rollup(df, "year", "department"), - expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") ), "year", "department" ) @@ -1920,6 +2028,27 @@ test_that("test multi-dimensional aggregations with cube and rollup", { mean(c(21000, 32000, 22000)), # 2017 22000, 32000, 21000 # 2017 each department ), + grouping_year = c( + 1, # global + 0, # 2016 + 0, 0, 0, # 2016 each department + 0, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_department = c( + 1, # global + 1, # 2016 + 0, 0, 0, # 2016 each department + 1, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_id = c( + 3, # 11 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), stringsAsFactors = FALSE ) @@ -2095,6 +2224,23 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { unlink(jsonPath2) unlink(jsonPath3) + + # Join with broadcast hint + df1 <- sql("SELECT * FROM range(10e10)") + 
df2 <- sql("SELECT * FROM range(10e10)") + + execution_plan <- capture.output(explain(join(df1, df2, df1$id == df2$id))) + expect_false(any(grepl("BroadcastHashJoin", execution_plan))) + + execution_plan_hint <- capture.output( + explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) + + execution_plan_broadcast <- capture.output( + explain(join(df1, broadcast(df2), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_broadcast))) }) test_that("toJSON() on DataFrame", { @@ -2131,7 +2277,7 @@ test_that("isLocal()", { expect_false(isLocal(df)) }) -test_that("union(), rbind(), except(), and intersect() on a DataFrame", { +test_that("union(), unionByName(), rbind(), except(), and intersect() on a DataFrame", { df <- read.json(jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", @@ -2147,6 +2293,13 @@ test_that("union(), rbind(), except(), and intersect() on a DataFrame", { expect_equal(first(unioned)$name, "Michael") expect_equal(count(arrange(suppressWarnings(unionAll(df, df2)), df$age)), 6) + df1 <- select(df2, "age", "name") + unioned1 <- arrange(unionByName(df1, df), df1$age) + expect_is(unioned, "SparkDataFrame") + expect_equal(count(unioned), 6) + # Here, we test if 'Michael' in df is correctly mapped to the same name. + expect_equal(first(unioned)$name, "Michael") + unioned2 <- arrange(rbind(unioned, df, df2), df$age) expect_is(unioned2, "SparkDataFrame") expect_equal(count(unioned2), 12) @@ -2290,34 +2443,36 @@ test_that("read/write ORC files - compression option", { }) test_that("read/write Parquet files", { - df <- read.df(jsonPath, "json") - # Test write.df and read.df - write.df(df, parquetPath, "parquet", mode = "overwrite") - df2 <- read.df(parquetPath, "parquet") - expect_is(df2, "SparkDataFrame") - expect_equal(count(df2), 3) - - # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - write.parquet(df, parquetPath2) - parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - suppressWarnings(saveAsParquetFile(df, parquetPath3)) - parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) - expect_is(parquetDF, "SparkDataFrame") - expect_equal(count(parquetDF), count(df) * 2) - parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) - expect_is(parquetDF2, "SparkDataFrame") - expect_equal(count(parquetDF2), count(df) * 2) - - # Test if varargs works with variables - saveMode <- "overwrite" - mergeSchema <- "true" - parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) - - unlink(parquetPath2) - unlink(parquetPath3) - unlink(parquetPath4) + if (windows_with_hadoop()) { + df <- read.df(jsonPath, "json") + # Test write.df and read.df + write.df(df, parquetPath, "parquet", mode = "overwrite") + df2 <- read.df(parquetPath, "parquet") + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + + # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + write.parquet(df, parquetPath2) + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + suppressWarnings(saveAsParquetFile(df, parquetPath3)) + parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) + expect_is(parquetDF, "SparkDataFrame") + expect_equal(count(parquetDF), count(df) * 2) + 
parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) + expect_is(parquetDF2, "SparkDataFrame") + expect_equal(count(parquetDF2), count(df) * 2) + + # Test if varargs works with variables + saveMode <- "overwrite" + mergeSchema <- "true" + parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) + + unlink(parquetPath2) + unlink(parquetPath3) + unlink(parquetPath4) + } }) test_that("read/write Parquet files - compression option/mode", { @@ -2371,7 +2526,7 @@ test_that("read/write text files - compression option", { unlink(textPath) }) -test_that("describe() and summarize() on a DataFrame", { +test_that("describe() and summary() on a DataFrame", { df <- read.json(jsonPath) stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") @@ -2382,9 +2537,16 @@ test_that("describe() and summarize() on a DataFrame", { expect_equal(collect(stats)[5, "age"], "30") stats2 <- summary(df) - expect_equal(collect(stats2)[4, "summary"], "min") + expect_equal(collect(stats2)[5, "summary"], "25%") expect_equal(collect(stats2)[5, "age"], "30") + stats3 <- summary(df, "min", "max", "55.1%") + + expect_equal(collect(stats3)[1, "summary"], "min") + expect_equal(collect(stats3)[2, "summary"], "max") + expect_equal(collect(stats3)[3, "summary"], "55.1%") + expect_equal(collect(stats3)[3, "age"], "30") + # SPARK-16425: SparkR summary() fails on column of type logical df <- withColumn(df, "boolean", df$age == 30) summary(df) @@ -2616,15 +2778,15 @@ test_that("attach() on a DataFrame", { expected_age <- data.frame(age = c(NA, 30, 19)) expect_equal(head(age), expected_age) stat <- summary(age) - expect_equal(collect(stat)[5, "age"], "30") + expect_equal(collect(stat)[8, "age"], "30") age <- age$age + 1 expect_is(age, "Column") rm(age) stat2 <- summary(age) - expect_equal(collect(stat2)[5, "age"], "30") + expect_equal(collect(stat2)[8, "age"], "30") detach("df") stat3 <- summary(df[, "age", drop = F]) - expect_equal(collect(stat3)[5, "age"], "30") + expect_equal(collect(stat3)[8, "age"], "30") expect_error(age) }) @@ -2777,30 +2939,33 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { expect_identical(ldf, result) # Filter and add a column - schema <- structType(structField("a", "integer"), structField("b", "double"), - structField("c", "string"), structField("d", "integer")) - df1 <- dapply( - df, - function(x) { - y <- x[x$a > 1, ] - y <- cbind(y, y$a + 1L) - }, - schema) - result <- collect(df1) - expected <- ldf[ldf$a > 1, ] - expected$d <- expected$a + 1L - rownames(expected) <- NULL - expect_identical(expected, result) - - result <- dapplyCollect( - df, - function(x) { - y <- x[x$a > 1, ] - y <- cbind(y, y$a + 1L) - }) - expected1 <- expected - names(expected1) <- names(result) - expect_identical(expected1, result) + schemas <- list(structType(structField("a", "integer"), structField("b", "double"), + structField("c", "string"), structField("d", "integer")), + "a INT, b DOUBLE, c STRING, d INT") + for (schema in schemas) { + df1 <- dapply( + df, + function(x) { + y <- x[x$a > 1, ] + y <- cbind(y, y$a + 1L) + }, + schema) + result <- collect(df1) + expected <- ldf[ldf$a > 1, ] + expected$d <- expected$a + 1L + rownames(expected) <- NULL + expect_identical(expected, result) + + result <- dapplyCollect( + df, + function(x) { + y <- x[x$a > 1, ] + y <- cbind(y, y$a + 1L) + }) + expected1 <- expected + names(expected1) <- names(result) + expect_identical(expected1, 
result) + } # Remove the added column df2 <- dapply( @@ -2822,7 +2987,6 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { }) test_that("dapplyCollect() on DataFrame with a binary column", { - df <- data.frame(key = 1:3) df$bytes <- lapply(df$key, serialize, connection = NULL) @@ -2913,29 +3077,32 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { # Computes the sum of second column by grouping on the first and third columns # and checks if the sum is larger than 2 - schema <- structType(structField("a", "integer"), structField("e", "boolean")) - df2 <- gapply( - df, - c(df$"a", df$"c"), - function(key, x) { - y <- data.frame(key[1], sum(x$b) > 2) - }, - schema) - actual <- collect(df2)$e - expected <- c(TRUE, TRUE) - expect_identical(actual, expected) - - df2Collect <- gapplyCollect( - df, - c(df$"a", df$"c"), - function(key, x) { - y <- data.frame(key[1], sum(x$b) > 2) - colnames(y) <- c("a", "e") - y - }) - actual <- df2Collect$e + schemas <- list(structType(structField("a", "integer"), structField("e", "boolean")), + "a INT, e BOOLEAN") + for (schema in schemas) { + df2 <- gapply( + df, + c(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + }, + schema) + actual <- collect(df2)$e + expected <- c(TRUE, TRUE) expect_identical(actual, expected) + df2Collect <- gapplyCollect( + df, + c(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + colnames(y) <- c("a", "e") + y + }) + actual <- df2Collect$e + expect_identical(actual, expected) + } + # Computes the arithmetic mean of the second column by grouping # on the first and third columns. Output the groupping value and the average. schema <- structType(structField("a", "integer"), structField("c", "string"), @@ -3141,9 +3308,9 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. expect_error(read.df(source = "json"), - paste("Error in loadDF : analysis error - Unable to infer schema for JSON.", + paste("Error in load : analysis error - Unable to infer schema for JSON.", "It must be specified manually")) - expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") + expect_error(read.df("arbitrary_path"), "Error in load : analysis error - Path does not exist") expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist") expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist") @@ -3161,6 +3328,22 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume "Unnamed arguments ignored: 2, 3, a.") }) +test_that("Specify a schema by using a DDL-formatted string when reading", { + # Test read.df with a user defined schema in a DDL-formatted string. + df1 <- read.df(jsonPath, "json", "name STRING, age DOUBLE") + expect_is(df1, "SparkDataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + expect_error(read.df(jsonPath, "json", "name stri"), "DataType stri is not supported.") + + # Test loadDF with a user defined schema in a DDL-formatted string. 
+ df2 <- loadDF(jsonPath, "json", "name STRING, age DOUBLE") + expect_is(df2, "SparkDataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + expect_error(loadDF(jsonPath, "json", "name stri"), "DataType stri is not supported.") +}) + test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", { ldf <- data.frame(col1 = c(0, 1, 2), col2 = c(as.POSIXct("2017-01-01 00:00:01"), diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R similarity index 82% rename from R/pkg/inst/tests/testthat/test_streaming.R rename to R/pkg/tests/fulltests/test_streaming.R index b125cb0591de..54f40bbd5f51 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -21,10 +21,10 @@ context("Structured Streaming") # Tests for Structured Streaming functions in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsonSubDir <- file.path("sparkr-test", "json", "") -if (.Platform$OS.type == "windows") { +if (is_windows()) { # file.path removes the empty separator on Windows, adds it back jsonSubDir <- paste0(jsonSubDir, .Platform$file.sep) } @@ -46,6 +46,8 @@ schema <- structType(structField("name", "string"), structField("age", "integer"), structField("count", "double")) +stringSchema <- "name STRING, age INTEGER, count DOUBLE" + test_that("read.stream, write.stream, awaitTermination, stopQuery", { df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) @@ -53,10 +55,12 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3) writeLines(mockLinesNa, jsonPathNa) awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6) stopQuery(q) @@ -71,6 +75,7 @@ test_that("print from explain, lastProgress, status, isActive", { q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete") awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==") expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q))))) @@ -93,6 +98,7 @@ test_that("Stream other format", { q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) expect_equal(queryName(q), "people3") @@ -107,6 +113,27 @@ test_that("Stream other format", { unlink(parquetPath) }) +test_that("Specify a schema by using a DDL-formatted string when reading", { + # Test read.stream with a user defined schema in a DDL-formatted string. 
+ parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + df <- read.df(jsonPath, "json", schema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") + expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + + expect_error(read.stream(path = parquetPath, schema = "name stri"), + "DataType stri is not supported.") + + unlink(parquetPath) +}) + test_that("Non-streaming DataFrame", { c <- as.DataFrame(cars) expect_false(isStreaming(c)) diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/tests/fulltests/test_take.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_take.R rename to R/pkg/tests/fulltests/test_take.R index aaa532856c3d..8936cc57da22 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/tests/fulltests/test_take.R @@ -30,7 +30,7 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", "raising me. But they're both dead now. I didn't kill them. Honest.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/tests/fulltests/test_textFile.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_textFile.R rename to R/pkg/tests/fulltests/test_textFile.R index 3b466066e939..be2d2711ff88 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/tests/fulltests/test_textFile.R @@ -18,7 +18,7 @@ context("the textFile() function") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/tests/fulltests/test_utils.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_utils.R rename to R/pkg/tests/fulltests/test_utils.R index 1ca383da26ec..af81423aa8dd 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -18,7 +18,7 @@ context("functions in utils.R") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("convertJListToRList() gives back (deserializes) the original JLists @@ -134,7 +134,7 @@ test_that("cleanClosure on R functions", { # Test for broadcast variables. 
a <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) - aBroadcast <- broadcast(sc, a) + aBroadcast <- broadcastRDD(sc, a) normMultiply <- function(x) { norm(aBroadcast$value) * x } newnormMultiply <- SparkR:::cleanClosure(normMultiply) env <- environment(newnormMultiply) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 29812f872c78..a1834a220261 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -21,14 +21,33 @@ library(SparkR) # Turn all warnings into errors options("warn" = 2) +if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") +} + # Setup global test environment # Install Spark first to set SPARK_HOME install.spark() sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") -sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") invisible(lapply(sparkRWhitelistSQLDirs, function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) +sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) + +sparkRTestMaster <- "local[1]" +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + sparkRTestMaster <- "" +} test_package("SparkR") + +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + # set random seed for predictable results. mostly for base's sample() in tree and classification + set.seed(42) + # for testthat 1.0.2 later, change reporter from "summary" to default_reporter() + testthat:::run_tests("SparkR", + file.path(sparkRDir, "pkg", "tests", "fulltests"), + NULL, + "summary") +} diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 4b9d6c380609..caeae72e37bb 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -27,6 +27,17 @@ vignette: > limitations under the License. --> +```{r setup, include=FALSE} +library(knitr) +opts_hooks$set(eval = function(options) { + # override eval to FALSE only on windows + if (.Platform$OS.type == "windows") { + options$eval = FALSE + } + options +}) +``` + ## Overview SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). @@ -46,8 +57,9 @@ We use default settings in which it runs in local mode. It auto downloads Spark ```{r, include=FALSE} install.spark() +sparkR.session(master = "local[1]") ``` -```{r, message=FALSE, results="hide"} +```{r, eval=FALSE} sparkR.session() ``` @@ -65,7 +77,7 @@ We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` fun head(carsDF) ``` -Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +Common data processing operations such as `filter` and `select` are supported on the `SparkDataFrame`. ```{r} carsSubDF <- select(carsDF, "model", "mpg", "hp") carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) @@ -182,7 +194,7 @@ head(df) ``` ### Data Sources -SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. 
You can check the Spark SQL Programming Guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. @@ -232,7 +244,7 @@ write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite" ``` ### Hive Tables -You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL Programming Guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). ```{r, eval=FALSE} sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") @@ -314,7 +326,7 @@ Use `cube` or `rollup` to compute subtotals across multiple dimensions. mean(cube(carsDF, "cyl", "gear", "am"), "mpg") ``` -generates groupings for {(`cyl`, `gear`, `am`), (`cyl`, `gear`), (`cyl`), ()}, while +generates groupings for {(`cyl`, `gear`, `am`), (`cyl`, `gear`), (`cyl`), ()}, while ```{r} mean(rollup(carsDF, "cyl", "gear", "am"), "mpg") @@ -379,7 +391,7 @@ out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` -Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `dapply`, `dapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of the function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} out <- dapplyCollect( @@ -405,7 +417,7 @@ result <- gapply( head(arrange(result, "max_mpg", decreasing = TRUE)) ``` -Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. 
+Like `gapply`, `gapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} result <- gapplyCollect( @@ -458,20 +470,20 @@ options(ops) ### SQL Queries -A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL so that one can run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. ```{r} people <- read.df(paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/people.json"), "json") ``` -Register this SparkDataFrame as a temporary view. +Register this `SparkDataFrame` as a temporary view. ```{r} createOrReplaceTempView(people, "people") ``` -SQL statements can be run by using the sql method. +SQL statements can be run using the sql method. ```{r} teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") head(teenagers) @@ -502,6 +514,8 @@ SparkR supports the following machine learning models and algorithms. #### Tree - Classification and Regression +* Decision Tree + * Gradient-Boosted Trees (GBT) * Random Forest @@ -672,6 +686,7 @@ head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. + ```{r, warning=FALSE} library(survival) ovarianDF <- createDataFrame(ovarian) @@ -774,16 +789,32 @@ newDF <- createDataFrame(data.frame(x = c(1.5, 3.2))) head(predict(isoregModel, newDF)) ``` +#### Decision Tree + +`spark.decisionTree` fits a [decision tree](https://en.wikipedia.org/wiki/Decision_tree_learning) classification or regression model on a `SparkDataFrame`. +Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. 
+ +We use the `Titanic` dataset to train a decision tree and make predictions: + +```{r} +t <- as.data.frame(Titanic) +df <- createDataFrame(t) +dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2) +summary(dtModel) +predictions <- predict(dtModel, df) +``` + #### Gradient-Boosted Trees `spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. -Similar to the random forest example above, we use the `longley` dataset to train a gradient-boosted tree and make predictions: +We use the `Titanic` dataset to train a gradient-boosted tree and make predictions: -```{r, warning=FALSE} -df <- createDataFrame(longley) -gbtModel <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 2, maxIter = 2) +```{r} +t <- as.data.frame(Titanic) +df <- createDataFrame(t) +gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2) summary(gbtModel) predictions <- predict(gbtModel, df) ``` @@ -793,11 +824,12 @@ predictions <- predict(gbtModel, df) `spark.randomForest` fits a [random forest](https://en.wikipedia.org/wiki/Random_forest) classification or regression model on a `SparkDataFrame`. Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. -In the following example, we use the `longley` dataset to train a random forest and make predictions: +In the following example, we use the `Titanic` dataset to train a random forest and make predictions: -```{r, warning=FALSE} -df <- createDataFrame(longley) -rfModel <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 2, numTrees = 2) +```{r} +t <- as.data.frame(Titanic) +df <- createDataFrame(t) +rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2) summary(rfModel) predictions <- predict(rfModel, df) ``` @@ -819,7 +851,7 @@ head(select(fitted, "Class", "prediction")) `spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. -We use a simulated example to demostrate the usage. +We use a simulated example to demonstrate the usage. ```{r} X1 <- data.frame(V1 = rnorm(4), V2 = rnorm(4)) X2 <- data.frame(V1 = rnorm(6, 3), V2 = rnorm(6, 4)) @@ -850,9 +882,9 @@ head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20 * Topics and documents both exist in a feature space, where feature vectors are vectors of word counts (bag of words). -* Rather than estimating a clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. +* Rather than clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. -To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two type options for the column: +To use LDA, we need to specify a `features` column in `data` where each entry represents a document. 
There are two options for the column: * character string: This can be a string of the whole document. It will be parsed automatically. Additional stop words can be added in `customizedStopWords`. @@ -900,9 +932,9 @@ perplexity `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). -There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. +There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. -```{r} +```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(ratings, c("user", "item", "rating")) @@ -910,7 +942,7 @@ model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, nonnegati ``` Extract latent factors. -```{r} +```{r, eval=FALSE} stats <- summary(model) userFactors <- stats$userFactors itemFactors <- stats$itemFactors @@ -920,7 +952,7 @@ head(itemFactors) Make predictions. -```{r} +```{r, eval=FALSE} predicted <- predict(model, df) head(predicted) ``` @@ -963,24 +995,25 @@ Given a `SparkDataFrame`, the test compares continuous data in a given column `t specified by parameter `nullHypothesis`. Users can call `summary` to get a summary of the test results. -In the following example, we test whether the `longley` dataset's `Armed_Forces` column +In the following example, we test whether the `Titanic` dataset's `Freq` column follows a normal distribution. We set the parameters of the normal distribution using the mean and standard deviation of the sample. -```{r, warning=FALSE} -df <- createDataFrame(longley) -afStats <- head(select(df, mean(df$Armed_Forces), sd(df$Armed_Forces))) -afMean <- afStats[1] -afStd <- afStats[2] +```{r} +t <- as.data.frame(Titanic) +df <- createDataFrame(t) +freqStats <- head(select(df, mean(df$Freq), sd(df$Freq))) +freqMean <- freqStats[1] +freqStd <- freqStats[2] -test <- spark.kstest(df, "Armed_Forces", "norm", c(afMean, afStd)) +test <- spark.kstest(df, "Freq", "norm", c(freqMean, freqStd)) testSummary <- summary(test) testSummary ``` ### Model Persistence -The following example shows how to save/load an ML model by SparkR. +The following example shows how to save/load an ML model in SparkR. ```{r} t <- as.data.frame(Titanic) training <- createDataFrame(t) @@ -1002,6 +1035,72 @@ unlink(modelPath) ``` +## Structured Streaming + +SparkR supports the Structured Streaming API (experimental). + +You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. + +### Simple Source and Sink + +Spark has a few built-in input sources. 
As an example, to test with a socket source reading text into words and displaying the computed word counts: + +```{r, eval=FALSE} +# Create DataFrame representing the stream of input lines from connection +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") +``` + +### Kafka Source + +It is simple to read data from Kafka. For more information, see [Input Sources](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#input-sources) supported by Structured Streaming. + +```{r, eval=FALSE} +topic <- read.stream("kafka", + kafka.bootstrap.servers = "host1:port1,host2:port2", + subscribe = "topic1") +keyvalue <- selectExpr(topic, "CAST(key AS STRING)", "CAST(value AS STRING)") +``` + +### Operations and Sinks + +Most of the common operations on `SparkDataFrame` are supported for streaming, including selection, projection, and aggregation. Once you have defined the final result, to start the streaming computation, you will call the `write.stream` method setting a sink and `outputMode`. + +A streaming `SparkDataFrame` can be written for debugging to the console, to a temporary in-memory table, or for further processing in a fault-tolerant manner to a File Sink in different formats. + +```{r, eval=FALSE} +noAggDF <- select(where(deviceDataStreamingDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# Aggregate +aggDF <- count(groupBy(noAggDF, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +head(sql("select * from aggregates")) +``` + + ## Advanced Topics ### SparkR Object Classes @@ -1012,19 +1111,19 @@ There are three main object classes in SparkR you may be working with. + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + `env` saves the meta-information of the object such as `isCached`. -It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. + It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. -* `Column`: an S4 class representing column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. +* `Column`: an S4 class representing a column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding `Column` object in the Spark JVM backend. -It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. 
More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. + It can be obtained from a `SparkDataFrame` by `$` operator, e.g., `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. -* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend. +* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a `RelationalGroupedDataset` object in the backend. -This is often an intermediate object with group information and followed up by aggregation operations. + This is often an intermediate object with group information and followed up by aggregation operations. ### Architecture -A complete description of architecture can be seen in reference, in particular the paper *SparkR: Scaling R Programs with Spark*. +A complete description of architecture can be seen in the references, in particular the paper *SparkR: Scaling R Programs with Spark*. Under the hood of SparkR is Spark SQL engine. This avoids the overheads of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a bunch of optimizations to speed up the computation. diff --git a/R/run-tests.sh b/R/run-tests.sh index 742a2c5ed76d..29764f48bd15 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" diff --git a/appveyor.yml b/appveyor.yml index bbb27589cad0..dc2d81fcdc09 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -26,10 +26,13 @@ branches: only_commits: files: + - appveyor.yml + - dev/appveyor-install-dependencies.ps1 - R/ - sql/core/src/main/scala/org/apache/spark/sql/api/r/ - core/src/main/scala/org/apache/spark/api/r/ - mllib/src/main/scala/org/apache/spark/ml/r/ + - core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala cache: - C:\Users\appveyor\.m2 @@ -38,22 +41,20 @@ install: # Install maven and dependencies - ps: .\dev\appveyor-install-dependencies.ps1 # Required package for R unit tests - - cmd: R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('testthat')" - - cmd: R -e "install.packages('e1071', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('e1071')" - - cmd: R -e "install.packages('survival', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('survival')" + - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 
'e1071', 'survival'), repos='http://cran.us.r-project.org')" + - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')" build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package +environment: + NOT_CRAN: true + test_script: - - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R + - cmd: .\bin\spark-submit2.cmd --driver-java-options "-Dlog4j.configuration=file:///%CD:\=/%/R/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R notifications: - provider: Email on_build_success: false on_build_failure: false on_build_status_changed: false - diff --git a/assembly/pom.xml b/assembly/pom.xml index 742a4a1531e7..01fe354235e5 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -187,7 +187,7 @@ org.apache.maven.plugins maven-assembly-plugin - 3.0.0 + 3.1.0 dist @@ -220,11 +220,31 @@ provided + + orc-provided + + provided + + parquet-provided provided + + + + hadoop-cloud + + + org.apache.spark + spark-hadoop-cloud_${scala.binary.version} + ${project.version} + + + diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd index 0977025c2036..f946197b02d5 100644 --- a/bin/load-spark-env.cmd +++ b/bin/load-spark-env.cmd @@ -36,19 +36,19 @@ if [%SPARK_ENV_LOADED%] == [] ( rem Setting SPARK_SCALA_VERSION if not already set. set ASSEMBLY_DIR2="%SPARK_HOME%\assembly\target\scala-2.11" -set ASSEMBLY_DIR1="%SPARK_HOME%\assembly\target\scala-2.10" +set ASSEMBLY_DIR1="%SPARK_HOME%\assembly\target\scala-2.12" if [%SPARK_SCALA_VERSION%] == [] ( if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% ( - echo "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." - echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd." + echo "Presence of build for multiple Scala versions detected." + echo "Either clean one of them or, set SPARK_SCALA_VERSION in spark-env.cmd." exit 1 ) if exist %ASSEMBLY_DIR2% ( set SPARK_SCALA_VERSION=2.11 ) else ( - set SPARK_SCALA_VERSION=2.10 + set SPARK_SCALA_VERSION=2.12 ) ) exit /b 0 diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 8a2f709960a2..d05d94e68c81 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -47,17 +47,17 @@ fi if [ -z "$SPARK_SCALA_VERSION" ]; then ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11" - ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.10" + ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.12" if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then - echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 - echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 + echo -e "Presence of build for multiple Scala versions detected." 1>&2 + echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION in spark-env.sh.' 
1>&2 exit 1 fi if [ -d "$ASSEMBLY_DIR2" ]; then export SPARK_SCALA_VERSION="2.11" else - export SPARK_SCALA_VERSION="2.10" + export SPARK_SCALA_VERSION="2.12" fi fi diff --git a/bin/pyspark b/bin/pyspark index 98387c2ec5b8..dd286277c1fc 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$1" + exec "$PYSPARK_DRIVER_PYTHON" -m "$@" exit fi diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index f211c0873ad2..46d4d5c883cf 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/bin/spark-class b/bin/spark-class index 77ea40cc3794..65d3b9612909 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -72,6 +72,8 @@ build_command() { printf "%d\0" $? } +# Turn off posix mode since it does not allow process substitution +set +o posix CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 9faa7d65f83e..a93fd2f0e54b 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -29,7 +29,7 @@ if "x%1"=="x" ( ) rem Find Spark jars. -if exist "%SPARK_HOME%\RELEASE" ( +if exist "%SPARK_HOME%\jars" ( set SPARK_JARS_DIR="%SPARK_HOME%\jars" ) else ( set SPARK_JARS_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%\jars" @@ -51,7 +51,7 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" ( rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" ( - set RUNNER="%JAVA_HOME%\bin\java" + set RUNNER=%JAVA_HOME%\bin\java ) else ( where /q "%RUNNER%" if ERRORLEVEL 1 ( diff --git a/build/mvn b/build/mvn index 1e393c331dd8..efa4f9364ea5 100755 --- a/build/mvn +++ b/build/mvn @@ -91,13 +91,13 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { - local zinc_path="zinc-0.3.11/bin/zinc" + local zinc_path="zinc-0.3.15/bin/zinc" [ ! 
-f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} install_app \ - "${TYPESAFE_MIRROR}/zinc/0.3.11" \ - "zinc-0.3.11.tgz" \ + "${TYPESAFE_MIRROR}/zinc/0.3.15" \ + "zinc-0.3.15.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" } diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml new file mode 100644 index 000000000000..cf93d41cd77c --- /dev/null +++ b/common/kvstore/pom.xml @@ -0,0 +1,106 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../../pom.xml + + + spark-kvstore_2.11 + jar + Spark Project Local DB + http://spark.apache.org/ + + kvstore + + + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + com.google.guava + guava + + + org.fusesource.leveldbjni + leveldbjni-all + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + + commons-io + commons-io + test + + + log4j + log4j + test + + + org.slf4j + slf4j-api + test + + + org.slf4j + slf4j-log4j12 + test + + + io.dropwizard.metrics + metrics-core + test + + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/ArrayWrappers.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/ArrayWrappers.java new file mode 100644 index 000000000000..9bc8c55bd538 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/ArrayWrappers.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.util.Arrays; + +import com.google.common.base.Preconditions; + +/** + * A factory for array wrappers so that arrays can be used as keys in a map, sorted or not. + * + * The comparator implementation makes two assumptions: + * - All elements are instances of Comparable + * - When comparing two arrays, they both contain elements of the same type in corresponding + * indices. + * + * Otherwise, ClassCastExceptions may occur. The equality method can compare any two arrays. + * + * This class is not efficient and is mostly meant to compare really small arrays, like those + * generally used as indices and keys in a KVStore. 
+ */ +class ArrayWrappers { + + @SuppressWarnings("unchecked") + public static Comparable forArray(Object a) { + Preconditions.checkArgument(a.getClass().isArray()); + Comparable ret; + if (a instanceof int[]) { + ret = new ComparableIntArray((int[]) a); + } else if (a instanceof long[]) { + ret = new ComparableLongArray((long[]) a); + } else if (a instanceof byte[]) { + ret = new ComparableByteArray((byte[]) a); + } else { + Preconditions.checkArgument(!a.getClass().getComponentType().isPrimitive()); + ret = new ComparableObjectArray((Object[]) a); + } + return (Comparable) ret; + } + + private static class ComparableIntArray implements Comparable { + + private final int[] array; + + ComparableIntArray(int[] array) { + this.array = array; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ComparableIntArray)) { + return false; + } + return Arrays.equals(array, ((ComparableIntArray) other).array); + } + + @Override + public int hashCode() { + int code = 0; + for (int i = 0; i < array.length; i++) { + code = (code * 31) + array[i]; + } + return code; + } + + @Override + public int compareTo(ComparableIntArray other) { + int len = Math.min(array.length, other.array.length); + for (int i = 0; i < len; i++) { + int diff = array[i] - other.array[i]; + if (diff != 0) { + return diff; + } + } + + return array.length - other.array.length; + } + } + + private static class ComparableLongArray implements Comparable { + + private final long[] array; + + ComparableLongArray(long[] array) { + this.array = array; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ComparableLongArray)) { + return false; + } + return Arrays.equals(array, ((ComparableLongArray) other).array); + } + + @Override + public int hashCode() { + int code = 0; + for (int i = 0; i < array.length; i++) { + code = (code * 31) + (int) array[i]; + } + return code; + } + + @Override + public int compareTo(ComparableLongArray other) { + int len = Math.min(array.length, other.array.length); + for (int i = 0; i < len; i++) { + long diff = array[i] - other.array[i]; + if (diff != 0) { + return diff > 0 ? 
1 : -1; + } + } + + return array.length - other.array.length; + } + } + + private static class ComparableByteArray implements Comparable { + + private final byte[] array; + + ComparableByteArray(byte[] array) { + this.array = array; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ComparableByteArray)) { + return false; + } + return Arrays.equals(array, ((ComparableByteArray) other).array); + } + + @Override + public int hashCode() { + int code = 0; + for (int i = 0; i < array.length; i++) { + code = (code * 31) + array[i]; + } + return code; + } + + @Override + public int compareTo(ComparableByteArray other) { + int len = Math.min(array.length, other.array.length); + for (int i = 0; i < len; i++) { + int diff = array[i] - other.array[i]; + if (diff != 0) { + return diff; + } + } + + return array.length - other.array.length; + } + } + + private static class ComparableObjectArray implements Comparable { + + private final Object[] array; + + ComparableObjectArray(Object[] array) { + this.array = array; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ComparableObjectArray)) { + return false; + } + return Arrays.equals(array, ((ComparableObjectArray) other).array); + } + + @Override + public int hashCode() { + int code = 0; + for (int i = 0; i < array.length; i++) { + code = (code * 31) + array[i].hashCode(); + } + return code; + } + + @Override + @SuppressWarnings("unchecked") + public int compareTo(ComparableObjectArray other) { + int len = Math.min(array.length, other.array.length); + for (int i = 0; i < len; i++) { + int diff = ((Comparable) array[i]).compareTo((Comparable) other.array[i]); + if (diff != 0) { + return diff; + } + } + + return array.length - other.array.length; + } + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java new file mode 100644 index 000000000000..5ca437128519 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.kvstore; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; + +import org.apache.spark.annotation.Private; + +/** + * Implementation of KVStore that keeps data deserialized in memory. This store does not index + * data; instead, whenever iterating over an indexed field, the stored data is copied and sorted + * according to the index. This saves memory but makes iteration more expensive. + */ +@Private +public class InMemoryStore implements KVStore { + + private Object metadata; + private ConcurrentMap, InstanceList> data = new ConcurrentHashMap<>(); + + @Override + public T getMetadata(Class klass) { + return klass.cast(metadata); + } + + @Override + public void setMetadata(Object value) { + this.metadata = value; + } + + @Override + public long count(Class type) { + InstanceList list = data.get(type); + return list != null ? list.size() : 0; + } + + @Override + public long count(Class type, String index, Object indexedValue) throws Exception { + InstanceList list = data.get(type); + int count = 0; + Object comparable = asKey(indexedValue); + KVTypeInfo.Accessor accessor = list.getIndexAccessor(index); + for (Object o : view(type)) { + if (Objects.equal(comparable, asKey(accessor.get(o)))) { + count++; + } + } + return count; + } + + @Override + public T read(Class klass, Object naturalKey) { + InstanceList list = data.get(klass); + Object value = list != null ? list.get(naturalKey) : null; + if (value == null) { + throw new NoSuchElementException(); + } + return klass.cast(value); + } + + @Override + public void write(Object value) throws Exception { + InstanceList list = data.computeIfAbsent(value.getClass(), key -> { + try { + return new InstanceList(key); + } catch (Exception e) { + throw Throwables.propagate(e); + } + }); + list.put(value); + } + + @Override + public void delete(Class type, Object naturalKey) { + InstanceList list = data.get(type); + if (list != null) { + list.delete(naturalKey); + } + } + + @Override + public KVStoreView view(Class type){ + InstanceList list = data.get(type); + return list != null ? 
list.view(type) + : new InMemoryView<>(type, Collections.emptyList(), null); + } + + @Override + public void close() { + metadata = null; + data.clear(); + } + + @SuppressWarnings("unchecked") + private static Comparable asKey(Object in) { + if (in.getClass().isArray()) { + in = ArrayWrappers.forArray(in); + } + return (Comparable) in; + } + + private static class InstanceList { + + private final KVTypeInfo ti; + private final KVTypeInfo.Accessor naturalKey; + private final ConcurrentMap, Object> data; + + private int size; + + private InstanceList(Class type) throws Exception { + this.ti = new KVTypeInfo(type); + this.naturalKey = ti.getAccessor(KVIndex.NATURAL_INDEX_NAME); + this.data = new ConcurrentHashMap<>(); + this.size = 0; + } + + KVTypeInfo.Accessor getIndexAccessor(String indexName) { + return ti.getAccessor(indexName); + } + + public Object get(Object key) { + return data.get(asKey(key)); + } + + public void put(Object value) throws Exception { + Preconditions.checkArgument(ti.type().equals(value.getClass()), + "Unexpected type: %s", value.getClass()); + if (data.put(asKey(naturalKey.get(value)), value) == null) { + size++; + } + } + + public void delete(Object key) { + if (data.remove(asKey(key)) != null) { + size--; + } + } + + public int size() { + return size; + } + + @SuppressWarnings("unchecked") + public InMemoryView view(Class type) { + Preconditions.checkArgument(ti.type().equals(type), "Unexpected type: %s", type); + Collection all = (Collection) data.values(); + return new InMemoryView<>(type, all, ti); + } + + } + + private static class InMemoryView extends KVStoreView { + + private final Collection elements; + private final KVTypeInfo ti; + private final KVTypeInfo.Accessor natural; + + InMemoryView(Class type, Collection elements, KVTypeInfo ti) { + super(type); + this.elements = elements; + this.ti = ti; + this.natural = ti != null ? ti.getAccessor(KVIndex.NATURAL_INDEX_NAME) : null; + } + + @Override + public Iterator iterator() { + if (elements.isEmpty()) { + return new InMemoryIterator<>(elements.iterator()); + } + + try { + KVTypeInfo.Accessor getter = index != null ? ti.getAccessor(index) : null; + int modifier = ascending ? 1 : -1; + + final List sorted = copyElements(); + Collections.sort(sorted, (e1, e2) -> modifier * compare(e1, e2, getter)); + Stream stream = sorted.stream(); + + if (first != null) { + stream = stream.filter(e -> modifier * compare(e, getter, first) >= 0); + } + + if (last != null) { + stream = stream.filter(e -> modifier * compare(e, getter, last) <= 0); + } + + if (skip > 0) { + stream = stream.skip(skip); + } + + if (max < sorted.size()) { + stream = stream.limit((int) max); + } + + return new InMemoryIterator<>(stream.iterator()); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + /** + * Create a copy of the input elements, filtering the values for child indices if needed. 
+ */ + private List copyElements() { + if (parent != null) { + KVTypeInfo.Accessor parentGetter = ti.getParentAccessor(index); + Preconditions.checkArgument(parentGetter != null, "Parent filter for non-child index."); + + return elements.stream() + .filter(e -> compare(e, parentGetter, parent) == 0) + .collect(Collectors.toList()); + } else { + return new ArrayList<>(elements); + } + } + + private int compare(T e1, T e2, KVTypeInfo.Accessor getter) { + try { + int diff = compare(e1, getter, getter.get(e2)); + if (diff == 0 && getter != natural) { + diff = compare(e1, natural, natural.get(e2)); + } + return diff; + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + private int compare(T e1, KVTypeInfo.Accessor getter, Object v2) { + try { + return asKey(getter.get(e1)).compareTo(asKey(v2)); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + } + + private static class InMemoryIterator implements KVStoreIterator { + + private final Iterator iter; + + InMemoryIterator(Iterator iter) { + this.iter = iter; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public T next() { + return iter.next(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public List next(int max) { + List list = new ArrayList<>(max); + while (hasNext() && list.size() < max) { + list.add(next()); + } + return list; + } + + @Override + public boolean skip(long n) { + long skipped = 0; + while (skipped < n) { + if (hasNext()) { + next(); + skipped++; + } else { + return false; + } + } + + return hasNext(); + } + + @Override + public void close() { + // no op. + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVIndex.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVIndex.java new file mode 100644 index 000000000000..80f492110724 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVIndex.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.apache.spark.annotation.Private; + +/** + * Tags a field to be indexed when storing an object. + * + *
+ * Types are required to have a natural index that uniquely identifies instances in the store.
+ * The default value of the annotation identifies the natural index for the type.
+ * </p>
+ *
+ * <p>
+ * Indexes allow for more efficient sorting of data read from the store. By annotating a field or
+ * "getter" method with this annotation, an index will be created that will provide sorting based on
+ * the string value of that field.
+ * </p>
+ *
+ * <p>
+ * Note that creating indices means more space will be needed, and maintenance operations like
+ * updating or deleting a value will become more expensive.
+ * </p>
+ *
+ * <p>
+ * Indices are restricted to String, integral types (byte, short, int, long, boolean), and arrays
+ * of those values.
+ * </p>
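+ *
+ * <p>
+ * A minimal usage sketch (editor's illustration, not part of the original patch; the
+ * {@code AppSummary} type and its fields are hypothetical):
+ * </p>
+ *
+ * <pre>{@code
+ *   public class AppSummary {
+ *     // Natural index: uniquely identifies this instance in the store.
+ *     @KVIndex
+ *     public String id;
+ *
+ *     // Secondary index named "endTime", allowing instances to be sorted by end time.
+ *     @KVIndex("endTime")
+ *     public long endTime;
+ *   }
+ * }</pre>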
+ */ +@Private +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.FIELD, ElementType.METHOD}) +public @interface KVIndex { + + String NATURAL_INDEX_NAME = "__main__"; + + /** + * The name of the index to be created for the annotated entity. Must be unique within + * the class. Index names are not allowed to start with an underscore (that's reserved for + * internal use). The default value is the natural index name (which is always a copy index + * regardless of the annotation's values). + */ + String value() default NATURAL_INDEX_NAME; + + /** + * The name of the parent index of this index. By default there is no parent index, so the + * generated data can be retrieved without having to provide a parent value. + * + *

+ * If a parent index is defined, iterating over the data using the index will require providing + * a single value for the parent index. This serves as a rudimentary way to provide relationships + * between entities in the store. + *
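As an illustration of the parent mechanism, the sketch below adds a child index to a hypothetical `TaskData` type; the names are invented, only the `parent` attribute usage follows the description above. Iterating over the child index then requires supplying a value for the parent index, as shown in the `KVStoreView#parent` example further down.

```java
import org.apache.spark.util.kvstore.KVIndex;

public class TaskData {

  @KVIndex
  public String taskId;

  // Plain index: tasks can be iterated grouped by the stage they belong to.
  @KVIndex("stage")
  public int stageId;

  // Child index: only usable when a concrete "stage" value is provided, e.g.
  // to list the tasks of a single stage sorted by their runtime.
  @KVIndex(value = "runtime", parent = "stage")
  public long runtime;
}
```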

+ */ + String parent() default ""; + + /** + * Whether to copy the instance's data to the index, instead of just storing a pointer to the + * data. The default behavior is to just store a reference; that saves disk space but is slower + * to read, since there's a level of indirection. + */ + boolean copy() default false; + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStore.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStore.java new file mode 100644 index 000000000000..72d06a8ca807 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStore.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.io.Closeable; + +import org.apache.spark.annotation.Private; + +/** + * Abstraction for a local key/value store for storing app data. + * + *

+ * There are two main features provided by the implementations of this interface: + *

+ * + *

+ * Serialization

+ * + *

+ * If the underlying data store requires serialization, data will be serialized to and deserialized + * from the store using a {@link KVStoreSerializer}, which can be customized by the application. The + * serializer is based on Jackson, so it supports all the Jackson annotations for controlling the + * serialization of app-defined types. + *

+ * + *

+ * Data is also automatically compressed to save disk space. + *

+ * + *

+ * Automatic Key Management

+ * + *

+ * When using the built-in key management, the implementation will automatically create unique + * keys for each type written to the store. Keys are based on the type name, and always start + * with the "+" prefix character (so that it's easy to use both manual and automatic key + * management APIs without conflicts). + *

+ * + *

+ * Another feature of automatic key management is indexing; by annotating fields or methods of + * objects written to the store with {@link KVIndex}, indices are created to sort the data + * by the values of those properties. This makes it possible to provide sorting without having + * to load all instances of those types from the store. + *

+ * + *

+ * KVStore instances are thread-safe for both reads and writes. + *
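A rough usage sketch, assuming the `LevelDB` implementation added later in this patch and the hypothetical `JobData` type from the earlier example; the path is arbitrary and error handling is omitted.

```java
import java.io.File;

import org.apache.spark.util.kvstore.KVStore;
import org.apache.spark.util.kvstore.LevelDB;

public class KVStoreExample {

  public static void main(String[] args) throws Exception {
    // KVStore extends Closeable, so try-with-resources releases the DB handle.
    try (KVStore store = new LevelDB(new File("/tmp/kvstore-example"))) {
      JobData job = new JobData();
      job.jobId = "job-1";
      job.submissionTime = System.currentTimeMillis();
      job.status = "RUNNING";

      // write() also creates/updates the index entries declared with @KVIndex.
      store.write(job);

      // read() looks the object up by its natural key.
      JobData stored = store.read(JobData.class, "job-1");

      // count() is served from the data kept at the index end markers.
      long total = store.count(JobData.class);
      System.out.println(stored.jobId + " / " + total);

      store.delete(JobData.class, "job-1");
    }
  }
}
```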

+ */ +@Private +public interface KVStore extends Closeable { + + /** + * Returns app-specific metadata from the store, or null if it's not currently set. + * + *

+ * The metadata type is application-specific. This is a convenience method so that applications + * don't need to define their own keys for this information. + *
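For instance, a store-wide metadata blob might be used as sketched below; `AppMeta` is a made-up type, and the only requirement is that the configured `KVStoreSerializer` (Jackson by default) can handle it.

```java
import org.apache.spark.util.kvstore.KVStore;

public class MetadataExample {

  // Hypothetical metadata type; a no-arg constructor and public fields keep it
  // friendly to the default Jackson-based serializer.
  public static class AppMeta {
    public String version;
  }

  static void recordVersion(KVStore store) throws Exception {
    AppMeta meta = new AppMeta();
    meta.version = "2.3.0";
    store.setMetadata(meta);

    AppMeta stored = store.getMetadata(AppMeta.class);
    System.out.println(stored.version);
  }
}
```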

+ */ + T getMetadata(Class klass) throws Exception; + + /** + * Writes the given value in the store metadata key. + */ + void setMetadata(Object value) throws Exception; + + /** + * Read a specific instance of an object. + * + * @param naturalKey The object's "natural key", which uniquely identifies it. Null keys + * are not allowed. + * @throws java.util.NoSuchElementException If an element with the given key does not exist. + */ + T read(Class klass, Object naturalKey) throws Exception; + + /** + * Writes the given object to the store, including indexed fields. Indices are updated based + * on the annotated fields of the object's class. + * + *

+ * Writes may be slower when the object already exists in the store, since it will involve + * updating existing indices. + *

+ * + * @param value The object to write. + */ + void write(Object value) throws Exception; + + /** + * Removes an object and all data related to it, like index entries, from the store. + * + * @param type The object's type. + * @param naturalKey The object's "natural key", which uniquely identifies it. Null keys + * are not allowed. + * @throws java.util.NoSuchElementException If an element with the given key does not exist. + */ + void delete(Class type, Object naturalKey) throws Exception; + + /** + * Returns a configurable view for iterating over entities of the given type. + */ + KVStoreView view(Class type) throws Exception; + + /** + * Returns the number of items of the given type currently in the store. + */ + long count(Class type) throws Exception; + + /** + * Returns the number of items of the given type which match the given indexed value. + */ + long count(Class type, String index, Object indexedValue) throws Exception; + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreIterator.java new file mode 100644 index 000000000000..e6254a9368ff --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreIterator.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.io.Closeable; +import java.util.Iterator; +import java.util.List; + +import org.apache.spark.annotation.Private; + +/** + * An iterator for KVStore. + * + *

+ * Iterators may keep references to resources that need to be closed. It's recommended that users + * explicitly close iterators after they're used. + *
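A sketch of the recommended pattern, using try-with-resources and the `closeableIterator()` helper defined on `KVStoreView`; the `JobData` type and its "status" index are the hypothetical ones from the earlier examples.

```java
import org.apache.spark.util.kvstore.KVStore;
import org.apache.spark.util.kvstore.KVStoreIterator;

public class IteratorExample {

  // Counts jobs in a given status without loading everything into memory.
  static long countWithStatus(KVStore store, String status) throws Exception {
    long matches = 0;
    // try-with-resources closes the iterator even if we stop early or fail.
    try (KVStoreIterator<JobData> it =
        store.view(JobData.class).index("status").closeableIterator()) {
      while (it.hasNext()) {
        if (status.equals(it.next().getStatus())) {
          matches++;
        }
      }
    }
    return matches;
  }
}
```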

+ */ +@Private +public interface KVStoreIterator extends Iterator, Closeable { + + /** + * Retrieve multiple elements from the store. + * + * @param max Maximum number of elements to retrieve. + */ + List next(int max); + + /** + * Skip in the iterator. + * + * @return Whether there are items left after skipping. + */ + boolean skip(long n); + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java new file mode 100644 index 000000000000..bd8d9486acde --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.apache.spark.annotation.Private; + +/** + * Serializer used to translate between app-defined types and the LevelDB store. + * + *

+ * The serializer is based on Jackson, so values are written as JSON. It also allows "naked strings" + * and integers to be used as values directly; those are stored as UTF-8 strings. + *
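Because `mapper` is a protected field, an application can in principle subclass the serializer to tune Jackson; the features toggled below are purely illustrative choices, not something this patch requires. An instance of such a subclass would then be passed to the `LevelDB(File, KVStoreSerializer)` constructor shown later in this patch.

```java
import com.fasterxml.jackson.databind.SerializationFeature;

import org.apache.spark.util.kvstore.KVStoreSerializer;

public class CustomSerializer extends KVStoreSerializer {

  public CustomSerializer() {
    // Customize the protected ObjectMapper set up by the parent constructor.
    mapper.configure(SerializationFeature.INDENT_OUTPUT, false);
    mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
  }
}
```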

+ */ +@Private +public class KVStoreSerializer { + + /** + * Object mapper used to process app-specific types. If an application requires a specific + * configuration of the mapper, it can subclass this serializer and add custom configuration + * to this object. + */ + protected final ObjectMapper mapper; + + public KVStoreSerializer() { + this.mapper = new ObjectMapper(); + } + + public final byte[] serialize(Object o) throws Exception { + if (o instanceof String) { + return ((String) o).getBytes(UTF_8); + } else { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + GZIPOutputStream out = new GZIPOutputStream(bytes); + try { + mapper.writeValue(out, o); + } finally { + out.close(); + } + return bytes.toByteArray(); + } + } + + @SuppressWarnings("unchecked") + public final T deserialize(byte[] data, Class klass) throws Exception { + if (klass.equals(String.class)) { + return (T) new String(data, UTF_8); + } else { + GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data)); + try { + return mapper.readValue(in, klass); + } finally { + in.close(); + } + } + } + + final byte[] serialize(long value) { + return String.valueOf(value).getBytes(UTF_8); + } + + final long deserializeLong(byte[] data) { + return Long.parseLong(new String(data, UTF_8)); + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreView.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreView.java new file mode 100644 index 000000000000..8ea79bbe160d --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreView.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import com.google.common.base.Preconditions; + +import org.apache.spark.annotation.Private; + +/** + * A configurable view that allows iterating over values in a {@link KVStore}. + * + *

+ * The different methods can be used to configure the behavior of the iterator. Calling the same + * method multiple times is allowed; the most recent value will be used. + *

+ * + *

+ * The iterators returned by this view are of type {@link KVStoreIterator}; they auto-close + * when used in a for loop that exhausts their contents, but when used manually, they need + * to be closed explicitly unless all elements are read. + *
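Under the same hypothetical `JobData` assumptions, a configured view might look like the sketch below; the enhanced for loop exhausts the iterator, so it closes itself.

```java
import org.apache.spark.util.kvstore.KVStore;
import org.apache.spark.util.kvstore.KVStoreView;

public class ViewExample {

  // Prints up to ten jobs, ordered by submission time, newest first.
  static void printLatest(KVStore store) throws Exception {
    KVStoreView<JobData> view = store.view(JobData.class)
        .index("submissionTime")
        .reverse()
        .max(10);

    for (JobData job : view) {
      System.out.println(job.jobId + " " + job.submissionTime);
    }
  }
}
```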

+ */ +@Private +public abstract class KVStoreView implements Iterable { + + final Class type; + + boolean ascending = true; + String index = KVIndex.NATURAL_INDEX_NAME; + Object first = null; + Object last = null; + Object parent = null; + long skip = 0L; + long max = Long.MAX_VALUE; + + public KVStoreView(Class type) { + this.type = type; + } + + /** + * Reverses the order of iteration. By default, iterates in ascending order. + */ + public KVStoreView reverse() { + ascending = !ascending; + return this; + } + + /** + * Iterates according to the given index. + */ + public KVStoreView index(String name) { + this.index = Preconditions.checkNotNull(name); + return this; + } + + /** + * Defines the value of the parent index when iterating over a child index. Only elements that + * match the parent index's value will be included in the iteration. + * + *

+ * This is required when iterating over a child index; an error will be generated when iterating + * over a parent-less index. + *
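For the hypothetical child index sketched earlier ("runtime" with parent "stage" on `TaskData`), iteration could look roughly like this:

```java
import org.apache.spark.util.kvstore.KVStore;
import org.apache.spark.util.kvstore.KVStoreIterator;

public class ChildIndexExample {

  // Lists the tasks of one stage, sorted by the child "runtime" index.
  static void printStage(KVStore store, int stageId) throws Exception {
    try (KVStoreIterator<TaskData> it = store.view(TaskData.class)
        .index("runtime")
        .parent(stageId)   // mandatory: "runtime" declares parent = "stage"
        .closeableIterator()) {
      while (it.hasNext()) {
        System.out.println(it.next().taskId);
      }
    }
  }
}
```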

+ */ + public KVStoreView parent(Object value) { + this.parent = value; + return this; + } + + /** + * Iterates starting at the given value of the chosen index (inclusive). + */ + public KVStoreView first(Object value) { + this.first = value; + return this; + } + + /** + * Stops iteration at the given value of the chosen index (inclusive). + */ + public KVStoreView last(Object value) { + this.last = value; + return this; + } + + /** + * Stops iteration after a number of elements has been retrieved. + */ + public KVStoreView max(long max) { + Preconditions.checkArgument(max > 0L, "max must be positive."); + this.max = max; + return this; + } + + /** + * Skips a number of elements at the start of iteration. Skipped elements are not accounted + * when using {@link #max(long)}. + */ + public KVStoreView skip(long n) { + this.skip = n; + return this; + } + + /** + * Returns an iterator for the current configuration. + */ + public KVStoreIterator closeableIterator() throws Exception { + return (KVStoreIterator) iterator(); + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java new file mode 100644 index 000000000000..a2b077e4531e --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; + +import com.google.common.base.Preconditions; + +import org.apache.spark.annotation.Private; + +/** + * Wrapper around types managed in a KVStore, providing easy access to their indexed fields. 
+ */ +@Private +public class KVTypeInfo { + + private final Class type; + private final Map indices; + private final Map accessors; + + public KVTypeInfo(Class type) throws Exception { + this.type = type; + this.accessors = new HashMap<>(); + this.indices = new HashMap<>(); + + for (Field f : type.getDeclaredFields()) { + KVIndex idx = f.getAnnotation(KVIndex.class); + if (idx != null) { + checkIndex(idx, indices); + indices.put(idx.value(), idx); + f.setAccessible(true); + accessors.put(idx.value(), new FieldAccessor(f)); + } + } + + for (Method m : type.getDeclaredMethods()) { + KVIndex idx = m.getAnnotation(KVIndex.class); + if (idx != null) { + checkIndex(idx, indices); + Preconditions.checkArgument(m.getParameterTypes().length == 0, + "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); + indices.put(idx.value(), idx); + m.setAccessible(true); + accessors.put(idx.value(), new MethodAccessor(m)); + } + } + + Preconditions.checkArgument(indices.containsKey(KVIndex.NATURAL_INDEX_NAME), + "No natural index defined for type %s.", type.getName()); + Preconditions.checkArgument(indices.get(KVIndex.NATURAL_INDEX_NAME).parent().isEmpty(), + "Natural index of %s cannot have a parent.", type.getName()); + + for (KVIndex idx : indices.values()) { + if (!idx.parent().isEmpty()) { + KVIndex parent = indices.get(idx.parent()); + Preconditions.checkArgument(parent != null, + "Cannot find parent %s of index %s.", idx.parent(), idx.value()); + Preconditions.checkArgument(parent.parent().isEmpty(), + "Parent index %s of index %s cannot be itself a child index.", idx.parent(), idx.value()); + } + } + } + + private void checkIndex(KVIndex idx, Map indices) { + Preconditions.checkArgument(idx.value() != null && !idx.value().isEmpty(), + "No name provided for index in type %s.", type.getName()); + Preconditions.checkArgument( + !idx.value().startsWith("_") || idx.value().equals(KVIndex.NATURAL_INDEX_NAME), + "Index name %s (in type %s) is not allowed.", idx.value(), type.getName()); + Preconditions.checkArgument(idx.parent().isEmpty() || !idx.parent().equals(idx.value()), + "Index %s cannot be parent of itself.", idx.value()); + Preconditions.checkArgument(!indices.containsKey(idx.value()), + "Duplicate index %s for type %s.", idx.value(), type.getName()); + } + + public Class type() { + return type; + } + + public Object getIndexValue(String indexName, Object instance) throws Exception { + return getAccessor(indexName).get(instance); + } + + public Stream indices() { + return indices.values().stream(); + } + + Accessor getAccessor(String indexName) { + Accessor a = accessors.get(indexName); + Preconditions.checkArgument(a != null, "No index %s.", indexName); + return a; + } + + Accessor getParentAccessor(String indexName) { + KVIndex index = indices.get(indexName); + return index.parent().isEmpty() ? null : getAccessor(index.parent()); + } + + /** + * Abstracts the difference between invoking a Field and a Method. 
+ */ + interface Accessor { + + Object get(Object instance) throws Exception; + + } + + private class FieldAccessor implements Accessor { + + private final Field field; + + FieldAccessor(Field field) { + this.field = field; + } + + @Override + public Object get(Object instance) throws Exception { + return field.get(instance); + } + + } + + private class MethodAccessor implements Accessor { + + private final Method method; + + MethodAccessor(Method method) { + this.method = method; + } + + @Override + public Object get(Object instance) throws Exception { + return method.invoke(instance); + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java new file mode 100644 index 000000000000..ff48b155fab3 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.io.File; +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.fusesource.leveldbjni.JniDBFactory; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.Options; +import org.iq80.leveldb.WriteBatch; + +import org.apache.spark.annotation.Private; + +/** + * Implementation of KVStore that uses LevelDB as the underlying data store. + */ +@Private +public class LevelDB implements KVStore { + + @VisibleForTesting + static final long STORE_VERSION = 1L; + + @VisibleForTesting + static final byte[] STORE_VERSION_KEY = "__version__".getBytes(UTF_8); + + /** DB key where app metadata is stored. */ + private static final byte[] METADATA_KEY = "__meta__".getBytes(UTF_8); + + /** DB key where type aliases are stored. */ + private static final byte[] TYPE_ALIASES_KEY = "__types__".getBytes(UTF_8); + + final AtomicReference _db; + final KVStoreSerializer serializer; + + /** + * Keep a mapping of class names to a shorter, unique ID managed by the store. This serves two + * purposes: make the keys stored on disk shorter, and spread out the keys, since class names + * will often have a long, redundant prefix (think "org.apache.spark."). 
+ */ + private final ConcurrentMap typeAliases; + private final ConcurrentMap, LevelDBTypeInfo> types; + + public LevelDB(File path) throws Exception { + this(path, new KVStoreSerializer()); + } + + public LevelDB(File path, KVStoreSerializer serializer) throws Exception { + this.serializer = serializer; + this.types = new ConcurrentHashMap<>(); + + Options options = new Options(); + options.createIfMissing(!path.exists()); + this._db = new AtomicReference<>(JniDBFactory.factory.open(path, options)); + + byte[] versionData = db().get(STORE_VERSION_KEY); + if (versionData != null) { + long version = serializer.deserializeLong(versionData); + if (version != STORE_VERSION) { + throw new UnsupportedStoreVersionException(); + } + } else { + db().put(STORE_VERSION_KEY, serializer.serialize(STORE_VERSION)); + } + + Map aliases; + try { + aliases = get(TYPE_ALIASES_KEY, TypeAliases.class).aliases; + } catch (NoSuchElementException e) { + aliases = new HashMap<>(); + } + typeAliases = new ConcurrentHashMap<>(aliases); + } + + @Override + public T getMetadata(Class klass) throws Exception { + try { + return get(METADATA_KEY, klass); + } catch (NoSuchElementException nsee) { + return null; + } + } + + @Override + public void setMetadata(Object value) throws Exception { + if (value != null) { + put(METADATA_KEY, value); + } else { + db().delete(METADATA_KEY); + } + } + + T get(byte[] key, Class klass) throws Exception { + byte[] data = db().get(key); + if (data == null) { + throw new NoSuchElementException(new String(key, UTF_8)); + } + return serializer.deserialize(data, klass); + } + + private void put(byte[] key, Object value) throws Exception { + Preconditions.checkArgument(value != null, "Null values are not allowed."); + db().put(key, serializer.serialize(value)); + } + + @Override + public T read(Class klass, Object naturalKey) throws Exception { + Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); + byte[] key = getTypeInfo(klass).naturalIndex().start(null, naturalKey); + return get(key, klass); + } + + @Override + public void write(Object value) throws Exception { + Preconditions.checkArgument(value != null, "Null values are not allowed."); + LevelDBTypeInfo ti = getTypeInfo(value.getClass()); + + try (WriteBatch batch = db().createWriteBatch()) { + byte[] data = serializer.serialize(value); + synchronized (ti) { + Object existing; + try { + existing = get(ti.naturalIndex().entityKey(null, value), value.getClass()); + } catch (NoSuchElementException e) { + existing = null; + } + + PrefixCache cache = new PrefixCache(value); + byte[] naturalKey = ti.naturalIndex().toKey(ti.naturalIndex().getValue(value)); + for (LevelDBTypeInfo.Index idx : ti.indices()) { + byte[] prefix = cache.getPrefix(idx); + idx.add(batch, value, existing, data, naturalKey, prefix); + } + db().write(batch); + } + } + } + + @Override + public void delete(Class type, Object naturalKey) throws Exception { + Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); + try (WriteBatch batch = db().createWriteBatch()) { + LevelDBTypeInfo ti = getTypeInfo(type); + byte[] key = ti.naturalIndex().start(null, naturalKey); + synchronized (ti) { + byte[] data = db().get(key); + if (data != null) { + Object existing = serializer.deserialize(data, type); + PrefixCache cache = new PrefixCache(existing); + byte[] keyBytes = ti.naturalIndex().toKey(ti.naturalIndex().getValue(existing)); + for (LevelDBTypeInfo.Index idx : ti.indices()) { + idx.remove(batch, existing, keyBytes, 
cache.getPrefix(idx)); + } + db().write(batch); + } + } + } catch (NoSuchElementException nse) { + // Ignore. + } + } + + @Override + public KVStoreView view(Class type) throws Exception { + return new KVStoreView(type) { + @Override + public Iterator iterator() { + try { + return new LevelDBIterator<>(LevelDB.this, this); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + }; + } + + @Override + public long count(Class type) throws Exception { + LevelDBTypeInfo.Index idx = getTypeInfo(type).naturalIndex(); + return idx.getCount(idx.end(null)); + } + + @Override + public long count(Class type, String index, Object indexedValue) throws Exception { + LevelDBTypeInfo.Index idx = getTypeInfo(type).index(index); + return idx.getCount(idx.end(null, indexedValue)); + } + + @Override + public void close() throws IOException { + synchronized (this._db) { + DB _db = this._db.getAndSet(null); + if (_db == null) { + return; + } + + try { + _db.close(); + } catch (IOException ioe) { + throw ioe; + } catch (Exception e) { + throw new IOException(e.getMessage(), e); + } + } + } + + /** + * Closes the given iterator if the DB is still open. Trying to close a JNI LevelDB handle + * with a closed DB can cause JVM crashes, so this ensures that situation does not happen. + */ + void closeIterator(LevelDBIterator it) throws IOException { + synchronized (this._db) { + DB _db = this._db.get(); + if (_db != null) { + it.close(); + } + } + } + + /** Returns metadata about indices for the given type. */ + LevelDBTypeInfo getTypeInfo(Class type) throws Exception { + LevelDBTypeInfo ti = types.get(type); + if (ti == null) { + LevelDBTypeInfo tmp = new LevelDBTypeInfo(this, type, getTypeAlias(type)); + ti = types.putIfAbsent(type, tmp); + if (ti == null) { + ti = tmp; + } + } + return ti; + } + + /** + * Try to avoid use-after close since that has the tendency of crashing the JVM. This doesn't + * prevent methods that retrieved the instance from using it after close, but hopefully will + * catch most cases; otherwise, we'll need some kind of locking. + */ + DB db() { + DB _db = this._db.get(); + if (_db == null) { + throw new IllegalStateException("DB is closed."); + } + return _db; + } + + private byte[] getTypeAlias(Class klass) throws Exception { + byte[] alias = typeAliases.get(klass.getName()); + if (alias == null) { + synchronized (typeAliases) { + byte[] tmp = String.valueOf(typeAliases.size()).getBytes(UTF_8); + alias = typeAliases.putIfAbsent(klass.getName(), tmp); + if (alias == null) { + alias = tmp; + put(TYPE_ALIASES_KEY, new TypeAliases(typeAliases)); + } + } + } + return alias; + } + + /** Needs to be public for Jackson. 
*/ + public static class TypeAliases { + + public Map aliases; + + TypeAliases(Map aliases) { + this.aliases = aliases; + } + + TypeAliases() { + this(null); + } + + } + + private static class PrefixCache { + + private final Object entity; + private final Map prefixes; + + PrefixCache(Object entity) { + this.entity = entity; + this.prefixes = new HashMap<>(); + } + + byte[] getPrefix(LevelDBTypeInfo.Index idx) throws Exception { + byte[] prefix = null; + if (idx.isChild()) { + prefix = prefixes.get(idx.parent()); + if (prefix == null) { + prefix = idx.parent().childPrefix(idx.parent().getValue(entity)); + prefixes.put(idx.parent(), prefix); + } + } + return prefix; + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java new file mode 100644 index 000000000000..b3ba76ba5805 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.iq80.leveldb.DBIterator; + +class LevelDBIterator implements KVStoreIterator { + + private final LevelDB db; + private final boolean ascending; + private final DBIterator it; + private final Class type; + private final LevelDBTypeInfo ti; + private final LevelDBTypeInfo.Index index; + private final byte[] indexKeyPrefix; + private final byte[] end; + private final long max; + + private boolean checkedNext; + private byte[] next; + private boolean closed; + private long count; + + LevelDBIterator(LevelDB db, KVStoreView params) throws Exception { + this.db = db; + this.ascending = params.ascending; + this.it = db.db().iterator(); + this.type = params.type; + this.ti = db.getTypeInfo(type); + this.index = ti.index(params.index); + this.max = params.max; + + Preconditions.checkArgument(!index.isChild() || params.parent != null, + "Cannot iterate over child index %s without parent value.", params.index); + byte[] parent = index.isChild() ? 
index.parent().childPrefix(params.parent) : null; + + this.indexKeyPrefix = index.keyPrefix(parent); + + byte[] firstKey; + if (params.first != null) { + if (ascending) { + firstKey = index.start(parent, params.first); + } else { + firstKey = index.end(parent, params.first); + } + } else if (ascending) { + firstKey = index.keyPrefix(parent); + } else { + firstKey = index.end(parent); + } + it.seek(firstKey); + + byte[] end = null; + if (ascending) { + if (params.last != null) { + end = index.end(parent, params.last); + } else { + end = index.end(parent); + } + } else { + if (params.last != null) { + end = index.start(parent, params.last); + } + if (it.hasNext()) { + // When descending, the caller may have set up the start of iteration at a non-existant + // entry that is guaranteed to be after the desired entry. For example, if you have a + // compound key (a, b) where b is a, integer, you may seek to the end of the elements that + // have the same "a" value by specifying Integer.MAX_VALUE for "b", and that value may not + // exist in the database. So need to check here whether the next value actually belongs to + // the set being returned by the iterator before advancing. + byte[] nextKey = it.peekNext().getKey(); + if (compare(nextKey, indexKeyPrefix) <= 0) { + it.next(); + } + } + } + this.end = end; + + if (params.skip > 0) { + skip(params.skip); + } + } + + @Override + public boolean hasNext() { + if (!checkedNext && !closed) { + next = loadNext(); + checkedNext = true; + } + if (!closed && next == null) { + try { + close(); + } catch (IOException ioe) { + throw Throwables.propagate(ioe); + } + } + return next != null; + } + + @Override + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + checkedNext = false; + + try { + T ret; + if (index == null || index.isCopy()) { + ret = db.serializer.deserialize(next, type); + } else { + byte[] key = ti.buildKey(false, ti.naturalIndex().keyPrefix(null), next); + ret = db.get(key, type); + } + next = null; + return ret; + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public List next(int max) { + List list = new ArrayList<>(max); + while (hasNext() && list.size() < max) { + list.add(next()); + } + return list; + } + + @Override + public boolean skip(long n) { + long skipped = 0; + while (skipped < n) { + if (next != null) { + checkedNext = false; + next = null; + skipped++; + continue; + } + + boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); + if (!hasNext) { + checkedNext = true; + return false; + } + + Map.Entry e = ascending ? it.next() : it.prev(); + if (!isEndMarker(e.getKey())) { + skipped++; + } + } + + return hasNext(); + } + + @Override + public synchronized void close() throws IOException { + if (!closed) { + it.close(); + closed = true; + } + } + + /** + * Because it's tricky to expose closeable iterators through many internal APIs, especially + * when Scala wrappers are used, this makes sure that, hopefully, the JNI resources held by + * the iterator will eventually be released. + */ + @Override + protected void finalize() throws Throwable { + db.closeIterator(this); + } + + private byte[] loadNext() { + if (count >= max) { + return null; + } + + try { + while (true) { + boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); + if (!hasNext) { + return null; + } + + Map.Entry nextEntry; + try { + // Avoid races if another thread is updating the DB. 
+ nextEntry = ascending ? it.next() : it.prev(); + } catch (NoSuchElementException e) { + return null; + } + + byte[] nextKey = nextEntry.getKey(); + // Next key is not part of the index, stop. + if (!startsWith(nextKey, indexKeyPrefix)) { + return null; + } + + // If the next key is an end marker, then skip it. + if (isEndMarker(nextKey)) { + continue; + } + + // If there's a known end key and iteration has gone past it, stop. + if (end != null) { + int comp = compare(nextKey, end) * (ascending ? 1 : -1); + if (comp > 0) { + return null; + } + } + + count++; + + // Next element is part of the iteration, return it. + return nextEntry.getValue(); + } + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + @VisibleForTesting + static boolean startsWith(byte[] key, byte[] prefix) { + if (key.length < prefix.length) { + return false; + } + + for (int i = 0; i < prefix.length; i++) { + if (key[i] != prefix[i]) { + return false; + } + } + + return true; + } + + private boolean isEndMarker(byte[] key) { + return (key.length > 2 && + key[key.length - 2] == LevelDBTypeInfo.KEY_SEPARATOR && + key[key.length - 1] == LevelDBTypeInfo.END_MARKER[0]); + } + + static int compare(byte[] a, byte[] b) { + int diff = 0; + int minLen = Math.min(a.length, b.length); + for (int i = 0; i < minLen; i++) { + diff += (a[i] - b[i]); + if (diff != 0) { + return diff; + } + } + + return a.length - b.length; + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java new file mode 100644 index 000000000000..232ee41dd0b1 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java @@ -0,0 +1,511 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.lang.reflect.Array; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.base.Preconditions; +import org.iq80.leveldb.WriteBatch; + +/** + * Holds metadata about app-specific types stored in LevelDB. Serves as a cache for data collected + * via reflection, to make it cheaper to access it multiple times. + * + *

+ * The hierarchy of keys stored in LevelDB looks roughly like the following. This hierarchy ensures + * that iteration over indices is easy, and that updating values in the store is not overly + * expensive. Of note, indices choose to use more disk space (one value per key) instead of keeping + * lists of pointers, which would be more expensive to update at runtime. + *

+ * + *

+ * Indentation defines when a sub-key lives under a parent key. In LevelDB, this means the full + * key would be the concatenation of everything up to that point in the hierarchy, with each + * component separated by a NULL byte. + *

+ * + *
+ * +TYPE_NAME
+ *   NATURAL_INDEX
+ *     +NATURAL_KEY
+ *     -
+ *   -NATURAL_INDEX
+ *   INDEX_NAME
+ *     +INDEX_VALUE
+ *       +NATURAL_KEY
+ *     -INDEX_VALUE
+ *     .INDEX_VALUE
+ *       CHILD_INDEX_NAME
+ *         +CHILD_INDEX_VALUE
+ *           NATURAL_KEY_OR_DATA
+ *         -
+ *   -INDEX_NAME
+ * 
+ * + *

+ * Entity data (either the entity's natural key or a copy of the data) is stored in all keys + * that end with "+". A count of all objects that match a particular top-level index + * value is kept at the end marker ("-"). A count is also kept at the natural index's end + * marker, to make it easy to retrieve the number of all elements of a particular type. + *

+ * + *

+ * To illustrate, given a type "Foo", with a natural index and a second index called "bar", you'd + * have these keys and values in the store for two instances, one with natural key "key1" and the + * other "key2", both with value "yes" for "bar": + *

+ * + *
+ * Foo __main__ +key1   [data for instance 1]
+ * Foo __main__ +key2   [data for instance 2]
+ * Foo __main__ -       [count of all Foo]
+ * Foo bar +yes +key1   [instance 1 key or data, depending on index type]
+ * Foo bar +yes +key2   [instance 2 key or data, depending on index type]
+ * Foo bar +yes -       [count of all Foo with "bar=yes" ]
+ * 
+ * + *

+ * Note that all indexed values are prepended with "+", even if the index itself does not have an + * explicit end marker. This allows for easily skipping to the end of an index by telling LevelDB + * to seek to the "phantom" end marker of the index. Throughout the code and comments, this part + * of the full LevelDB key is generally referred to as the "index value" of the entity. + *

+ * + *

+ * Child indices are stored after their parent index. In the example above, let's assume there is + * a child index "child", whose parent is "bar". If both instances have value "no" for this field, + * the data in the store would look something like the following: + *

+ * + *
+ * ...
+ * Foo bar +yes -
+ * Foo bar .yes .child +no +key1   [instance 1 key or data, depending on index type]
+ * Foo bar .yes .child +no +key2   [instance 2 key or data, depending on index type]
+ * ...
+ * 
+ */ +class LevelDBTypeInfo { + + static final byte[] END_MARKER = new byte[] { '-' }; + static final byte ENTRY_PREFIX = (byte) '+'; + static final byte KEY_SEPARATOR = 0x0; + static byte TRUE = (byte) '1'; + static byte FALSE = (byte) '0'; + + private static final byte SECONDARY_IDX_PREFIX = (byte) '.'; + private static final byte POSITIVE_MARKER = (byte) '='; + private static final byte NEGATIVE_MARKER = (byte) '*'; + private static final byte[] HEX_BYTES = new byte[] { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f' + }; + + private final LevelDB db; + private final Class type; + private final Map indices; + private final byte[] typePrefix; + + LevelDBTypeInfo(LevelDB db, Class type, byte[] alias) throws Exception { + this.db = db; + this.type = type; + this.indices = new HashMap<>(); + + KVTypeInfo ti = new KVTypeInfo(type); + + // First create the parent indices, then the child indices. + ti.indices().forEach(idx -> { + if (idx.parent().isEmpty()) { + indices.put(idx.value(), new Index(idx, ti.getAccessor(idx.value()), null)); + } + }); + ti.indices().forEach(idx -> { + if (!idx.parent().isEmpty()) { + indices.put(idx.value(), new Index(idx, ti.getAccessor(idx.value()), + indices.get(idx.parent()))); + } + }); + + this.typePrefix = alias; + } + + Class type() { + return type; + } + + byte[] keyPrefix() { + return typePrefix; + } + + Index naturalIndex() { + return index(KVIndex.NATURAL_INDEX_NAME); + } + + Index index(String name) { + Index i = indices.get(name); + Preconditions.checkArgument(i != null, "Index %s does not exist for type %s.", name, + type.getName()); + return i; + } + + Collection indices() { + return indices.values(); + } + + byte[] buildKey(byte[]... components) { + return buildKey(true, components); + } + + byte[] buildKey(boolean addTypePrefix, byte[]... components) { + int len = 0; + if (addTypePrefix) { + len += typePrefix.length + 1; + } + for (byte[] comp : components) { + len += comp.length; + } + len += components.length - 1; + + byte[] dest = new byte[len]; + int written = 0; + + if (addTypePrefix) { + System.arraycopy(typePrefix, 0, dest, 0, typePrefix.length); + dest[typePrefix.length] = KEY_SEPARATOR; + written += typePrefix.length + 1; + } + + for (byte[] comp : components) { + System.arraycopy(comp, 0, dest, written, comp.length); + written += comp.length; + if (written < dest.length) { + dest[written] = KEY_SEPARATOR; + written++; + } + } + + return dest; + } + + /** + * Models a single index in LevelDB. See top-level class's javadoc for a description of how the + * keys are generated. + */ + class Index { + + private final boolean copy; + private final boolean isNatural; + private final byte[] name; + private final KVTypeInfo.Accessor accessor; + private final Index parent; + + private Index(KVIndex self, KVTypeInfo.Accessor accessor, Index parent) { + byte[] name = self.value().getBytes(UTF_8); + if (parent != null) { + byte[] child = new byte[name.length + 1]; + child[0] = SECONDARY_IDX_PREFIX; + System.arraycopy(name, 0, child, 1, name.length); + } + + this.name = name; + this.isNatural = self.value().equals(KVIndex.NATURAL_INDEX_NAME); + this.copy = isNatural || self.copy(); + this.accessor = accessor; + this.parent = parent; + } + + boolean isCopy() { + return copy; + } + + boolean isChild() { + return parent != null; + } + + Index parent() { + return parent; + } + + /** + * Creates a key prefix for child indices of this index. 
This allows the prefix to be + * calculated only once, avoiding redundant work when multiple child indices of the + * same parent index exist. + */ + byte[] childPrefix(Object value) { + Preconditions.checkState(parent == null, "Not a parent index."); + return buildKey(name, toParentKey(value)); + } + + /** + * Gets the index value for a particular entity (which is the value of the field or method + * tagged with the index annotation). This is used as part of the LevelDB key where the + * entity (or its id) is stored. + */ + Object getValue(Object entity) throws Exception { + return accessor.get(entity); + } + + private void checkParent(byte[] prefix) { + if (prefix != null) { + Preconditions.checkState(parent != null, "Parent prefix provided for parent index."); + } else { + Preconditions.checkState(parent == null, "Parent prefix missing for child index."); + } + } + + /** The prefix for all keys that belong to this index. */ + byte[] keyPrefix(byte[] prefix) { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name) : buildKey(name); + } + + /** + * The key where to start ascending iteration for entities whose value for the indexed field + * match the given value. + */ + byte[] start(byte[] prefix, Object value) { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name, toKey(value)) + : buildKey(name, toKey(value)); + } + + /** The key for the index's end marker. */ + byte[] end(byte[] prefix) { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name, END_MARKER) + : buildKey(name, END_MARKER); + } + + /** The key for the end marker for entries with the given value. */ + byte[] end(byte[] prefix, Object value) { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name, toKey(value), END_MARKER) + : buildKey(name, toKey(value), END_MARKER); + } + + /** The full key in the index that identifies the given entity. */ + byte[] entityKey(byte[] prefix, Object entity) throws Exception { + Object indexValue = getValue(entity); + Preconditions.checkNotNull(indexValue, "Null index value for %s in type %s.", + name, type.getName()); + byte[] entityKey = start(prefix, indexValue); + if (!isNatural) { + entityKey = buildKey(false, entityKey, toKey(naturalIndex().getValue(entity))); + } + return entityKey; + } + + private void updateCount(WriteBatch batch, byte[] key, long delta) { + long updated = getCount(key) + delta; + if (updated > 0) { + batch.put(key, db.serializer.serialize(updated)); + } else { + batch.delete(key); + } + } + + private void addOrRemove( + WriteBatch batch, + Object entity, + Object existing, + byte[] data, + byte[] naturalKey, + byte[] prefix) throws Exception { + Object indexValue = getValue(entity); + Preconditions.checkNotNull(indexValue, "Null index value for %s in type %s.", + name, type.getName()); + + byte[] entityKey = start(prefix, indexValue); + if (!isNatural) { + entityKey = buildKey(false, entityKey, naturalKey); + } + + boolean needCountUpdate = (existing == null); + + // Check whether there's a need to update the index. The index needs to be updated in two + // cases: + // + // - There is no existing value for the entity, so a new index value will be added. + // - If there is a previously stored value for the entity, and the index value for the + // current index does not match the new value, the old entry needs to be deleted and + // the new one added. 
+ // + // Natural indices don't need to be checked, because by definition both old and new entities + // will have the same key. The put() call is all that's needed in that case. + // + // Also check whether we need to update the counts. If the indexed value is changing, we + // need to decrement the count at the old index value, and the new indexed value count needs + // to be incremented. + if (existing != null && !isNatural) { + byte[] oldPrefix = null; + Object oldIndexedValue = getValue(existing); + boolean removeExisting = !indexValue.equals(oldIndexedValue); + if (!removeExisting && isChild()) { + oldPrefix = parent().childPrefix(parent().getValue(existing)); + removeExisting = LevelDBIterator.compare(prefix, oldPrefix) != 0; + } + + if (removeExisting) { + if (oldPrefix == null && isChild()) { + oldPrefix = parent().childPrefix(parent().getValue(existing)); + } + + byte[] oldKey = entityKey(oldPrefix, existing); + batch.delete(oldKey); + + // If the indexed value has changed, we need to update the counts at the old and new + // end markers for the indexed value. + if (!isChild()) { + byte[] oldCountKey = end(null, oldIndexedValue); + updateCount(batch, oldCountKey, -1L); + needCountUpdate = true; + } + } + } + + if (data != null) { + byte[] stored = copy ? data : naturalKey; + batch.put(entityKey, stored); + } else { + batch.delete(entityKey); + } + + if (needCountUpdate && !isChild()) { + long delta = data != null ? 1L : -1L; + byte[] countKey = isNatural ? end(prefix) : end(prefix, indexValue); + updateCount(batch, countKey, delta); + } + } + + /** + * Add an entry to the index. + * + * @param batch Write batch with other related changes. + * @param entity The entity being added to the index. + * @param existing The entity being replaced in the index, or null. + * @param data Serialized entity to store (when storing the entity, not a reference). + * @param naturalKey The value's natural key (to avoid re-computing it for every index). + * @param prefix The parent index prefix, if this is a child index. + */ + void add( + WriteBatch batch, + Object entity, + Object existing, + byte[] data, + byte[] naturalKey, + byte[] prefix) throws Exception { + addOrRemove(batch, entity, existing, data, naturalKey, prefix); + } + + /** + * Remove a value from the index. + * + * @param batch Write batch with other related changes. + * @param entity The entity being removed, to identify the index entry to modify. + * @param naturalKey The value's natural key (to avoid re-computing it for every index). + * @param prefix The parent index prefix, if this is a child index. + */ + void remove( + WriteBatch batch, + Object entity, + byte[] naturalKey, + byte[] prefix) throws Exception { + addOrRemove(batch, entity, null, null, naturalKey, prefix); + } + + long getCount(byte[] key) { + byte[] data = db.db().get(key); + return data != null ? db.serializer.deserializeLong(data) : 0; + } + + byte[] toParentKey(Object value) { + return toKey(value, SECONDARY_IDX_PREFIX); + } + + byte[] toKey(Object value) { + return toKey(value, ENTRY_PREFIX); + } + + /** + * Translates a value to be used as part of the store key. + * + * Integral numbers are encoded as a string in a way that preserves lexicographical + * ordering. 
The string is prepended with a marker telling whether the number is negative + * or positive ("*" for negative and "=" for positive are used since "-" and "+" have the + * opposite of the desired order), and then the number is encoded into a hex string (so + * it occupies twice the number of bytes as the original type). + * + * Arrays are encoded by encoding each element separately, separated by KEY_SEPARATOR. + */ + byte[] toKey(Object value, byte prefix) { + final byte[] result; + + if (value instanceof String) { + byte[] str = ((String) value).getBytes(UTF_8); + result = new byte[str.length + 1]; + result[0] = prefix; + System.arraycopy(str, 0, result, 1, str.length); + } else if (value instanceof Boolean) { + result = new byte[] { prefix, (Boolean) value ? TRUE : FALSE }; + } else if (value.getClass().isArray()) { + int length = Array.getLength(value); + byte[][] components = new byte[length][]; + for (int i = 0; i < length; i++) { + components[i] = toKey(Array.get(value, i)); + } + result = buildKey(false, components); + } else { + int bytes; + + if (value instanceof Integer) { + bytes = Integer.SIZE; + } else if (value instanceof Long) { + bytes = Long.SIZE; + } else if (value instanceof Short) { + bytes = Short.SIZE; + } else if (value instanceof Byte) { + bytes = Byte.SIZE; + } else { + throw new IllegalArgumentException(String.format("Type %s not allowed as key.", + value.getClass().getName())); + } + + bytes = bytes / Byte.SIZE; + + byte[] key = new byte[bytes * 2 + 2]; + long longValue = ((Number) value).longValue(); + key[0] = prefix; + key[1] = longValue > 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; + + for (int i = 0; i < key.length - 2; i++) { + int masked = (int) ((longValue >>> (4 * i)) & 0xF); + key[key.length - i - 1] = HEX_BYTES[masked]; + } + + result = key; + } + + return result; + } + + } + +} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/UnsupportedStoreVersionException.java similarity index 68% rename from repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala rename to common/kvstore/src/main/java/org/apache/spark/util/kvstore/UnsupportedStoreVersionException.java index 94c801ebec7c..75a33b7c75de 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/UnsupportedStoreVersionException.java @@ -15,17 +15,16 @@ * limitations under the License. */ -package org.apache.spark.repl +package org.apache.spark.util.kvstore; -import scala.tools.nsc.Settings +import java.io.IOException; + +import org.apache.spark.annotation.Private; /** - * scala.tools.nsc.Settings implementation adding Spark-specific REPL - * command line options. + * Exception thrown when the store implementation is not compatible with the underlying data. 
*/ -private[repl] class SparkRunnerSettings(error: String => Unit) extends Settings(error) { - val loadfiles = MultiStringSetting( - "-i", - "file", - "load a file (assumes the code is given interactively)") +@Private +public class UnsupportedStoreVersionException extends IOException { + } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/ArrayKeyIndexType.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/ArrayKeyIndexType.java new file mode 100644 index 000000000000..32030fb4115c --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/ArrayKeyIndexType.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.util.Arrays; + +public class ArrayKeyIndexType { + + @KVIndex + public int[] key; + + @KVIndex("id") + public String[] id; + + @Override + public boolean equals(Object o) { + if (o instanceof ArrayKeyIndexType) { + ArrayKeyIndexType other = (ArrayKeyIndexType) o; + return Arrays.equals(key, other.key) && Arrays.equals(id, other.id); + } + return false; + } + + @Override + public int hashCode() { + return key.hashCode(); + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/ArrayWrappersSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/ArrayWrappersSuite.java new file mode 100644 index 000000000000..b1c8927d0761 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/ArrayWrappersSuite.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.kvstore; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class ArrayWrappersSuite { + + @Test + public void testGenericArrayKey() { + byte[] b1 = new byte[] { 0x01, 0x02, 0x03 }; + byte[] b2 = new byte[] { 0x01, 0x02 }; + int[] i1 = new int[] { 1, 2, 3 }; + int[] i2 = new int[] { 1, 2 }; + String[] s1 = new String[] { "1", "2", "3" }; + String[] s2 = new String[] { "1", "2" }; + + assertEquals(ArrayWrappers.forArray(b1), ArrayWrappers.forArray(b1)); + assertNotEquals(ArrayWrappers.forArray(b1), ArrayWrappers.forArray(b2)); + assertNotEquals(ArrayWrappers.forArray(b1), ArrayWrappers.forArray(i1)); + assertNotEquals(ArrayWrappers.forArray(b1), ArrayWrappers.forArray(s1)); + + assertEquals(ArrayWrappers.forArray(i1), ArrayWrappers.forArray(i1)); + assertNotEquals(ArrayWrappers.forArray(i1), ArrayWrappers.forArray(i2)); + assertNotEquals(ArrayWrappers.forArray(i1), ArrayWrappers.forArray(b1)); + assertNotEquals(ArrayWrappers.forArray(i1), ArrayWrappers.forArray(s1)); + + assertEquals(ArrayWrappers.forArray(s1), ArrayWrappers.forArray(s1)); + assertNotEquals(ArrayWrappers.forArray(s1), ArrayWrappers.forArray(s2)); + assertNotEquals(ArrayWrappers.forArray(s1), ArrayWrappers.forArray(b1)); + assertNotEquals(ArrayWrappers.forArray(s1), ArrayWrappers.forArray(i1)); + + assertEquals(0, ArrayWrappers.forArray(b1).compareTo(ArrayWrappers.forArray(b1))); + assertTrue(ArrayWrappers.forArray(b1).compareTo(ArrayWrappers.forArray(b2)) > 0); + + assertEquals(0, ArrayWrappers.forArray(i1).compareTo(ArrayWrappers.forArray(i1))); + assertTrue(ArrayWrappers.forArray(i1).compareTo(ArrayWrappers.forArray(i2)) > 0); + + assertEquals(0, ArrayWrappers.forArray(s1).compareTo(ArrayWrappers.forArray(s1))); + assertTrue(ArrayWrappers.forArray(s1).compareTo(ArrayWrappers.forArray(s2)) > 0); + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType1.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType1.java new file mode 100644 index 000000000000..92b643b0cb92 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType1.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
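The reason a wrapper type like ArrayWrappers is needed at all is that plain Java arrays inherit identity-based equals() and hashCode() from Object and do not implement Comparable, so they cannot be used directly as index values or map keys. A minimal demonstration of that underlying behavior (a standalone sketch, unrelated to the test above):

```java
import java.util.Arrays;

public class ArrayIdentityDemo {
  public static void main(String[] args) {
    int[] a = { 1, 2, 3 };
    int[] b = { 1, 2, 3 };

    // Same contents, but arrays compare by identity...
    System.out.println(a.equals(b));                          // false
    System.out.println(a.hashCode() == b.hashCode());         // almost certainly false

    // ...so content-based comparisons have to go through java.util.Arrays.
    System.out.println(Arrays.equals(a, b));                  // true
    System.out.println(Arrays.hashCode(a) == Arrays.hashCode(b)); // true
  }
}
```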
+ */ + +package org.apache.spark.util.kvstore; + +import com.google.common.base.Objects; + +public class CustomType1 { + + @KVIndex + public String key; + + @KVIndex("id") + public String id; + + @KVIndex(value = "name", copy = true) + public String name; + + @KVIndex("int") + public int num; + + @KVIndex(value = "child", parent = "id") + public String child; + + @Override + public boolean equals(Object o) { + if (o instanceof CustomType1) { + CustomType1 other = (CustomType1) o; + return id.equals(other.id) && name.equals(other.name); + } + return false; + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("key", key) + .add("id", id) + .add("name", name) + .add("num", num) + .toString(); + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java new file mode 100644 index 000000000000..9a81f86812cd --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java @@ -0,0 +1,504 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.util.Arrays; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Random; + +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import static org.junit.Assert.*; + +public abstract class DBIteratorSuite { + + private static final Logger LOG = LoggerFactory.getLogger(DBIteratorSuite.class); + + private static final int MIN_ENTRIES = 42; + private static final int MAX_ENTRIES = 1024; + private static final Random RND = new Random(); + + private static List allEntries; + private static List clashingEntries; + private static KVStore db; + + private interface BaseComparator extends Comparator { + /** + * Returns a comparator that falls back to natural order if this comparator's ordering + * returns equality for two elements. Used to mimic how the index sorts things internally. + */ + default BaseComparator fallback() { + return (t1, t2) -> { + int result = BaseComparator.this.compare(t1, t2); + if (result != 0) { + return result; + } + + return t1.key.compareTo(t2.key); + }; + } + + /** Reverses the order of this comparator. 
*/ + default BaseComparator reverse() { + return (t1, t2) -> -BaseComparator.this.compare(t1, t2); + } + } + + private static final BaseComparator NATURAL_ORDER = (t1, t2) -> t1.key.compareTo(t2.key); + private static final BaseComparator REF_INDEX_ORDER = (t1, t2) -> t1.id.compareTo(t2.id); + private static final BaseComparator COPY_INDEX_ORDER = (t1, t2) -> t1.name.compareTo(t2.name); + private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> t1.num - t2.num; + private static final BaseComparator CHILD_INDEX_ORDER = (t1, t2) -> t1.child.compareTo(t2.child); + + /** + * Implementations should override this method; it is called only once, before all tests are + * run. Any state can be safely stored in static variables and cleaned up in a @AfterClass + * handler. + */ + protected abstract KVStore createStore() throws Exception; + + @BeforeClass + public static void setupClass() { + long seed = RND.nextLong(); + LOG.info("Random seed: {}", seed); + RND.setSeed(seed); + } + + @AfterClass + public static void cleanupData() throws Exception { + allEntries = null; + db = null; + } + + @Before + public void setup() throws Exception { + if (db != null) { + return; + } + + db = createStore(); + + int count = RND.nextInt(MAX_ENTRIES) + MIN_ENTRIES; + + allEntries = new ArrayList<>(count); + for (int i = 0; i < count; i++) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + RND.nextInt(MAX_ENTRIES); + t.num = RND.nextInt(MAX_ENTRIES); + t.child = "child" + (i % MIN_ENTRIES); + allEntries.add(t); + } + + // Shuffle the entries to avoid the insertion order matching the natural ordering. Just in case. + Collections.shuffle(allEntries, RND); + for (CustomType1 e : allEntries) { + db.write(e); + } + + // Pick the first generated value, and forcefully create a few entries that will clash + // with the indexed values (id and name), to make sure the index behaves correctly when + // multiple entities are indexed by the same value. + // + // This also serves as a test for the test code itself, to make sure it's sorting indices + // the same way the store is expected to. + CustomType1 first = allEntries.get(0); + clashingEntries = new ArrayList<>(); + + int clashCount = RND.nextInt(MIN_ENTRIES) + 1; + for (int i = 0; i < clashCount; i++) { + CustomType1 t = new CustomType1(); + t.key = "n-key" + (count + i); + t.id = first.id; + t.name = first.name; + t.num = first.num; + t.child = first.child; + allEntries.add(t); + clashingEntries.add(t); + db.write(t); + } + + // Create another entry that could cause problems: take the first entry, and make its indexed + // name be an extension of the existing ones, to make sure the implementation sorts these + // correctly even considering the separator character (shorter strings first). 
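The fallback() helper defined earlier in this suite orders entries by the index value and then by natural key when the index values tie, which is how the store is expected to sort clashing entries internally. For reference, the same two-level ordering can be written with plain java.util.Comparator chaining; this is only an equivalent sketch using the CustomType1 fields shown above, not code from the patch:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class FallbackComparatorDemo {
  public static void main(String[] args) {
    // Order by the indexed "name" value; break ties with the natural key.
    Comparator<CustomType1> byName = Comparator.comparing(t -> t.name);
    Comparator<CustomType1> nameThenKey = byName.thenComparing(t -> t.key);

    CustomType1 a = new CustomType1();
    a.key = "key1";
    a.name = "name1";

    CustomType1 b = new CustomType1();
    b.key = "key2";
    b.name = "name1";   // clashes with a.name, so the natural key decides the order

    List<CustomType1> entries = new ArrayList<>(Arrays.asList(b, a));
    entries.sort(nameThenKey);   // -> [a, b]
    System.out.println(entries);
  }
}
```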
+ CustomType1 t = new CustomType1(); + t.key = "extended-key-0"; + t.id = first.id; + t.name = first.name + "a"; + t.num = first.num; + t.child = first.child; + allEntries.add(t); + db.write(t); + } + + @Test + public void naturalIndex() throws Exception { + testIteration(NATURAL_ORDER, view(), null, null); + } + + @Test + public void refIndex() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id"), null, null); + } + + @Test + public void copyIndex() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name"), null, null); + } + + @Test + public void numericIndex() throws Exception { + testIteration(NUMERIC_INDEX_ORDER, view().index("int"), null, null); + } + + @Test + public void childIndex() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id), null, null); + } + + @Test + public void naturalIndexDescending() throws Exception { + testIteration(NATURAL_ORDER, view().reverse(), null, null); + } + + @Test + public void refIndexDescending() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id").reverse(), null, null); + } + + @Test + public void copyIndexDescending() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").reverse(), null, null); + } + + @Test + public void numericIndexDescending() throws Exception { + testIteration(NUMERIC_INDEX_ORDER, view().index("int").reverse(), null, null); + } + + @Test + public void childIndexDescending() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).reverse(), null, null); + } + + @Test + public void naturalIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NATURAL_ORDER, view().first(first.key), first, null); + } + + @Test + public void refIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(REF_INDEX_ORDER, view().index("id").first(first.id), first, null); + } + + @Test + public void copyIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().index("name").first(first.name), first, null); + } + + @Test + public void numericIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").first(first.num), first, null); + } + + @Test + public void childIndexWithStart() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).first(any.child), null, + null); + } + + @Test + public void naturalIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NATURAL_ORDER, view().reverse().first(first.key), first, null); + } + + @Test + public void refIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").first(first.id), first, null); + } + + @Test + public void copyIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().reverse().index("name").first(first.name), first, null); + } + + @Test + public void numericIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").first(first.num), first, null); + } + + @Test + public void childIndexDescendingWithStart() throws Exception { + CustomType1 any = pickLimit(); + 
testIteration(CHILD_INDEX_ORDER, + view().index("child").parent(any.id).first(any.child).reverse(), null, null); + } + + @Test + public void naturalIndexWithSkip() throws Exception { + testIteration(NATURAL_ORDER, view().skip(pickCount()), null, null); + } + + @Test + public void refIndexWithSkip() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id").skip(pickCount()), null, null); + } + + @Test + public void copyIndexWithSkip() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").skip(pickCount()), null, null); + } + + @Test + public void childIndexWithSkip() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).skip(pickCount()), + null, null); + } + + @Test + public void naturalIndexWithMax() throws Exception { + testIteration(NATURAL_ORDER, view().max(pickCount()), null, null); + } + + @Test + public void copyIndexWithMax() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").max(pickCount()), null, null); + } + + @Test + public void childIndexWithMax() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).max(pickCount()), null, + null); + } + + @Test + public void naturalIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NATURAL_ORDER, view().last(last.key), null, last); + } + + @Test + public void refIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(REF_INDEX_ORDER, view().index("id").last(last.id), null, last); + } + + @Test + public void copyIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().index("name").last(last.name), null, last); + } + + @Test + public void numericIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").last(last.num), null, last); + } + + @Test + public void childIndexWithLast() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).last(any.child), null, + null); + } + + @Test + public void naturalIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NATURAL_ORDER, view().reverse().last(last.key), null, last); + } + + @Test + public void refIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").last(last.id), null, last); + } + + @Test + public void copyIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().reverse().index("name").last(last.name), + null, last); + } + + @Test + public void numericIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").last(last.num), + null, last); + } + + @Test + public void childIndexDescendingWithLast() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).last(any.child).reverse(), + null, null); + } + + @Test + public void testRefWithIntNaturalKey() throws Exception { + LevelDBSuite.IntKeyType i = new LevelDBSuite.IntKeyType(); + i.key = 1; + i.id = "1"; + i.values = Arrays.asList("1"); + + db.write(i); + + try(KVStoreIterator it = db.view(i.getClass()).closeableIterator()) { + Object read = it.next(); + 
assertEquals(i, read); + } + } + + private CustomType1 pickLimit() { + // Picks an element that has clashes with other elements in the given index. + return clashingEntries.get(RND.nextInt(clashingEntries.size())); + } + + private int pickCount() { + int count = RND.nextInt(allEntries.size() / 2); + return Math.max(count, 1); + } + + /** + * Compares the two values and falls back to comparing the natural key of CustomType1 + * if they're the same, to mimic the behavior of the indexing code. + */ + private > int compareWithFallback( + T v1, + T v2, + CustomType1 ct1, + CustomType1 ct2) { + int result = v1.compareTo(v2); + if (result != 0) { + return result; + } + + return ct1.key.compareTo(ct2.key); + } + + private void testIteration( + final BaseComparator order, + final KVStoreView params, + final CustomType1 first, + final CustomType1 last) throws Exception { + List indexOrder = sortBy(order.fallback()); + if (!params.ascending) { + indexOrder = Lists.reverse(indexOrder); + } + + Iterable expected = indexOrder; + BaseComparator expectedOrder = params.ascending ? order : order.reverse(); + + if (params.parent != null) { + expected = Iterables.filter(expected, v -> params.parent.equals(v.id)); + } + + if (first != null) { + expected = Iterables.filter(expected, v -> expectedOrder.compare(first, v) <= 0); + } + + if (last != null) { + expected = Iterables.filter(expected, v -> expectedOrder.compare(v, last) <= 0); + } + + if (params.skip > 0) { + expected = Iterables.skip(expected, (int) params.skip); + } + + if (params.max != Long.MAX_VALUE) { + expected = Iterables.limit(expected, (int) params.max); + } + + List actual = collect(params); + compareLists(expected, actual); + } + + /** Could use assertEquals(), but that creates hard to read errors for large lists. 
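testIteration above mirrors, with plain Java collections, what the store is expected to do when a view is configured: apply the parent filter, then the first/last bounds in the effective direction, then skip, then max. Put together, a query against one of these stores reads roughly like the following usage sketch; it only uses view methods exercised by these tests, and the store contents and the "name30" bound are assumed for illustration:

```java
// Usage sketch; assumes `db` is a populated KVStore (InMemoryStore or LevelDB).
static void printSomeEntries(KVStore db) throws Exception {
  try (KVStoreIterator<CustomType1> it = db.view(CustomType1.class)
      .index("name")        // iterate in "name" index order instead of natural-key order
      .reverse()            // descending
      .first("name30")      // start from this index value (inclusive, in the chosen direction)
      .skip(5)              // then drop the next 5 entries
      .max(20)              // and return at most 20
      .closeableIterator()) {
    while (it.hasNext()) {
      CustomType1 entry = it.next();
      System.out.println(entry.key + " / " + entry.name);
    }
  }
}
```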
*/ + private void compareLists(Iterable expected, List actual) { + Iterator expectedIt = expected.iterator(); + Iterator actualIt = actual.iterator(); + + int count = 0; + while (expectedIt.hasNext()) { + if (!actualIt.hasNext()) { + break; + } + count++; + assertEquals(expectedIt.next(), actualIt.next()); + } + + String message; + Object[] remaining; + int expectedCount = count; + int actualCount = count; + + if (expectedIt.hasNext()) { + remaining = Iterators.toArray(expectedIt, Object.class); + expectedCount += remaining.length; + message = "missing"; + } else { + remaining = Iterators.toArray(actualIt, Object.class); + actualCount += remaining.length; + message = "stray"; + } + + assertEquals(String.format("Found %s elements: %s", message, Arrays.asList(remaining)), + expectedCount, actualCount); + } + + private KVStoreView view() throws Exception { + return db.view(CustomType1.class); + } + + private List collect(KVStoreView view) throws Exception { + return Arrays.asList(Iterables.toArray(view, CustomType1.class)); + } + + private List sortBy(Comparator comp) { + List copy = new ArrayList<>(allEntries); + Collections.sort(copy, comp); + return copy; + } + +} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryIteratorSuite.java similarity index 68% rename from repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala rename to common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryIteratorSuite.java index fba321be9188..27dde6a9fbea 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryIteratorSuite.java @@ -15,23 +15,13 @@ * limitations under the License. */ -package org.apache.spark.repl +package org.apache.spark.util.kvstore; -import org.apache.spark.internal.Logging +public class InMemoryIteratorSuite extends DBIteratorSuite { -object Main extends Logging { - - initializeLogIfNecessary(true) - Signaling.cancelOnInterrupt() - - private var _interp: SparkILoop = _ - - def interp = _interp - - def interp_=(i: SparkILoop) { _interp = i } - - def main(args: Array[String]) { - _interp = new SparkILoop - _interp.process(args) + @Override + protected KVStore createStore() { + return new InMemoryStore(); } + } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java new file mode 100644 index 000000000000..510b3058a4e3 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.kvstore; + +import java.util.NoSuchElementException; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class InMemoryStoreSuite { + + @Test + public void testObjectWriteReadDelete() throws Exception { + KVStore store = new InMemoryStore(); + + CustomType1 t = new CustomType1(); + t.key = "key"; + t.id = "id"; + t.name = "name"; + + try { + store.read(CustomType1.class, t.key); + fail("Expected exception for non-existant object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + + store.write(t); + assertEquals(t, store.read(t.getClass(), t.key)); + assertEquals(1L, store.count(t.getClass())); + + store.delete(t.getClass(), t.key); + try { + store.read(t.getClass(), t.key); + fail("Expected exception for deleted object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + } + + @Test + public void testMultipleObjectWriteReadDelete() throws Exception { + KVStore store = new InMemoryStore(); + + CustomType1 t1 = new CustomType1(); + t1.key = "key1"; + t1.id = "id"; + t1.name = "name1"; + + CustomType1 t2 = new CustomType1(); + t2.key = "key2"; + t2.id = "id"; + t2.name = "name2"; + + store.write(t1); + store.write(t2); + + assertEquals(t1, store.read(t1.getClass(), t1.key)); + assertEquals(t2, store.read(t2.getClass(), t2.key)); + assertEquals(2L, store.count(t1.getClass())); + + store.delete(t1.getClass(), t1.key); + assertEquals(t2, store.read(t2.getClass(), t2.key)); + store.delete(t2.getClass(), t2.key); + try { + store.read(t2.getClass(), t2.key); + fail("Expected exception for deleted object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + } + + @Test + public void testMetadata() throws Exception { + KVStore store = new InMemoryStore(); + assertNull(store.getMetadata(CustomType1.class)); + + CustomType1 t = new CustomType1(); + t.id = "id"; + t.name = "name"; + + store.setMetadata(t); + assertEquals(t, store.getMetadata(CustomType1.class)); + + store.setMetadata(null); + assertNull(store.getMetadata(CustomType1.class)); + } + + @Test + public void testUpdate() throws Exception { + KVStore store = new InMemoryStore(); + + CustomType1 t = new CustomType1(); + t.key = "key"; + t.id = "id"; + t.name = "name"; + + store.write(t); + + t.name = "anotherName"; + + store.write(t); + assertEquals(1, store.count(t.getClass())); + assertSame(t, store.read(t.getClass(), t.key)); + } + + @Test + public void testArrayIndices() throws Exception { + KVStore store = new InMemoryStore(); + + ArrayKeyIndexType o = new ArrayKeyIndexType(); + o.key = new int[] { 1, 2 }; + o.id = new String[] { "3", "4" }; + + store.write(o); + assertEquals(o, store.read(ArrayKeyIndexType.class, o.key)); + assertEquals(o, store.view(ArrayKeyIndexType.class).index("id").first(o.id).iterator().next()); + } + + @Test + public void testBasicIteration() throws Exception { + KVStore store = new InMemoryStore(); + + CustomType1 t1 = new CustomType1(); + t1.key = "1"; + t1.id = "id1"; + t1.name = "name1"; + store.write(t1); + + CustomType1 t2 = new CustomType1(); + t2.key = "2"; + t2.id = "id2"; + t2.name = "name2"; + store.write(t2); + + assertEquals(t1.id, store.view(t1.getClass()).iterator().next().id); + assertEquals(t2.id, store.view(t1.getClass()).skip(1).iterator().next().id); + assertEquals(t2.id, store.view(t1.getClass()).skip(1).max(1).iterator().next().id); + assertEquals(t1.id, + store.view(t1.getClass()).first(t1.key).max(1).iterator().next().id); + assertEquals(t2.id, + 
store.view(t1.getClass()).first(t2.key).max(1).iterator().next().id); + assertFalse(store.view(t1.getClass()).first(t2.id).skip(1).iterator().hasNext()); + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBBenchmark.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBBenchmark.java new file mode 100644 index 000000000000..268d025f5f06 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBBenchmark.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.Slf4jReporter; +import com.codahale.metrics.Snapshot; +import com.codahale.metrics.Timer; +import org.apache.commons.io.FileUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.slf4j.LoggerFactory; +import static org.junit.Assert.*; + +/** + * A set of small benchmarks for the LevelDB implementation. 
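The InMemoryStoreSuite tests above exercise the basic KVStore contract: write, read and count by class and natural key, delete, and a single per-store metadata object. Condensed into one place, that contract looks like this sketch (example data assumed; error handling elided as in the tests):

```java
public class KVStoreBasicsSketch {
  public static void main(String[] args) throws Exception {
    KVStore store = new InMemoryStore();

    CustomType1 t = new CustomType1();
    t.key = "key";
    t.id = "id";
    t.name = "name";

    store.write(t);                                    // insert (or update, keyed by the natural index)
    CustomType1 back = store.read(CustomType1.class, "key");
    long n = store.count(CustomType1.class);           // 1

    store.delete(CustomType1.class, "key");            // reading it again now throws NoSuchElementException

    store.setMetadata(t);                              // one metadata object per store
    CustomType1 meta = store.getMetadata(CustomType1.class);
    store.setMetadata(null);                           // clears it

    System.out.println(back + " / " + n + " / " + meta);
  }
}
```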
+ * + * The benchmarks are run over two different types (one with just a natural index, and one + * with a ref index), over a set of 2^20 elements, and the following tests are performed: + * + * - write (then update) elements in sequential natural key order + * - write (then update) elements in random natural key order + * - iterate over natural index, ascending and descending + * - iterate over ref index, ascending and descending + */ +@Ignore +public class LevelDBBenchmark { + + private static final int COUNT = 1024; + private static final AtomicInteger IDGEN = new AtomicInteger(); + private static final MetricRegistry metrics = new MetricRegistry(); + private static final Timer dbCreation = metrics.timer("dbCreation"); + private static final Timer dbClose = metrics.timer("dbClose"); + + private LevelDB db; + private File dbpath; + + @Before + public void setup() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + try(Timer.Context ctx = dbCreation.time()) { + db = new LevelDB(dbpath); + } + } + + @After + public void cleanup() throws Exception { + if (db != null) { + try(Timer.Context ctx = dbClose.time()) { + db.close(); + } + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @AfterClass + public static void report() { + if (metrics.getTimers().isEmpty()) { + return; + } + + int headingPrefix = 0; + for (Map.Entry e : metrics.getTimers().entrySet()) { + headingPrefix = Math.max(e.getKey().length(), headingPrefix); + } + headingPrefix += 4; + + StringBuilder heading = new StringBuilder(); + for (int i = 0; i < headingPrefix; i++) { + heading.append(" "); + } + heading.append("\tcount"); + heading.append("\tmean"); + heading.append("\tmin"); + heading.append("\tmax"); + heading.append("\t95th"); + System.out.println(heading); + + for (Map.Entry e : metrics.getTimers().entrySet()) { + StringBuilder row = new StringBuilder(); + row.append(e.getKey()); + for (int i = 0; i < headingPrefix - e.getKey().length(); i++) { + row.append(" "); + } + + Snapshot s = e.getValue().getSnapshot(); + row.append("\t").append(e.getValue().getCount()); + row.append("\t").append(toMs(s.getMean())); + row.append("\t").append(toMs(s.getMin())); + row.append("\t").append(toMs(s.getMax())); + row.append("\t").append(toMs(s.get95thPercentile())); + + System.out.println(row); + } + + Slf4jReporter.forRegistry(metrics).outputTo(LoggerFactory.getLogger(LevelDBBenchmark.class)) + .build().report(); + } + + private static String toMs(double nanos) { + return String.format("%.3f", nanos / 1000 / 1000); + } + + @Test + public void sequentialWritesNoIndex() throws Exception { + List entries = createSimpleType(); + writeAll(entries, "sequentialWritesNoIndex"); + writeAll(entries, "sequentialUpdatesNoIndex"); + deleteNoIndex(entries, "sequentialDeleteNoIndex"); + } + + @Test + public void randomWritesNoIndex() throws Exception { + List entries = createSimpleType(); + + Collections.shuffle(entries); + writeAll(entries, "randomWritesNoIndex"); + + Collections.shuffle(entries); + writeAll(entries, "randomUpdatesNoIndex"); + + Collections.shuffle(entries); + deleteNoIndex(entries, "randomDeletesNoIndex"); + } + + @Test + public void sequentialWritesIndexedType() throws Exception { + List entries = createIndexedType(); + writeAll(entries, "sequentialWritesIndexed"); + writeAll(entries, "sequentialUpdatesIndexed"); + deleteIndexed(entries, "sequentialDeleteIndexed"); + } + + @Test + public void randomWritesIndexedTypeAndIteration() throws Exception { + List entries = 
createIndexedType(); + + Collections.shuffle(entries); + writeAll(entries, "randomWritesIndexed"); + + Collections.shuffle(entries); + writeAll(entries, "randomUpdatesIndexed"); + + // Run iteration benchmarks here since we've gone through the trouble of writing all + // the data already. + KVStoreView view = db.view(IndexedType.class); + iterate(view, "naturalIndex"); + iterate(view.reverse(), "naturalIndexDescending"); + iterate(view.index("name"), "refIndex"); + iterate(view.index("name").reverse(), "refIndexDescending"); + + Collections.shuffle(entries); + deleteIndexed(entries, "randomDeleteIndexed"); + } + + private void iterate(KVStoreView view, String name) throws Exception { + Timer create = metrics.timer(name + "CreateIterator"); + Timer iter = metrics.timer(name + "Iteration"); + KVStoreIterator it = null; + { + // Create the iterator several times, just to have multiple data points. + for (int i = 0; i < 1024; i++) { + if (it != null) { + it.close(); + } + try(Timer.Context ctx = create.time()) { + it = view.closeableIterator(); + } + } + } + + for (; it.hasNext(); ) { + try(Timer.Context ctx = iter.time()) { + it.next(); + } + } + } + + private void writeAll(List entries, String timerName) throws Exception { + Timer timer = newTimer(timerName); + for (Object o : entries) { + try(Timer.Context ctx = timer.time()) { + db.write(o); + } + } + } + + private void deleteNoIndex(List entries, String timerName) throws Exception { + Timer delete = newTimer(timerName); + for (SimpleType i : entries) { + try(Timer.Context ctx = delete.time()) { + db.delete(i.getClass(), i.key); + } + } + } + + private void deleteIndexed(List entries, String timerName) throws Exception { + Timer delete = newTimer(timerName); + for (IndexedType i : entries) { + try(Timer.Context ctx = delete.time()) { + db.delete(i.getClass(), i.key); + } + } + } + + private List createSimpleType() { + List entries = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + SimpleType t = new SimpleType(); + t.key = IDGEN.getAndIncrement(); + t.name = "name" + (t.key % 1024); + entries.add(t); + } + return entries; + } + + private List createIndexedType() { + List entries = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + IndexedType t = new IndexedType(); + t.key = IDGEN.getAndIncrement(); + t.name = "name" + (t.key % 1024); + entries.add(t); + } + return entries; + } + + private Timer newTimer(String name) { + assertNull("Timer already exists: " + name, metrics.getTimers().get(name)); + return metrics.timer(name); + } + + public static class SimpleType { + + @KVIndex + public int key; + + public String name; + + } + + public static class IndexedType { + + @KVIndex + public int key; + + @KVIndex("name") + public String name; + + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBIteratorSuite.java new file mode 100644 index 000000000000..f8195da58cf9 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBIteratorSuite.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.io.File; + +import org.apache.commons.io.FileUtils; +import org.junit.AfterClass; + +public class LevelDBIteratorSuite extends DBIteratorSuite { + + private static File dbpath; + private static LevelDB db; + + @AfterClass + public static void cleanup() throws Exception { + if (db != null) { + db.close(); + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @Override + protected KVStore createStore() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + db = new LevelDB(dbpath); + return db; + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java new file mode 100644 index 000000000000..2b07d249d202 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.NoSuchElementException; + +import org.apache.commons.io.FileUtils; +import org.iq80.leveldb.DBIterator; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import static org.junit.Assert.*; + +public class LevelDBSuite { + + private LevelDB db; + private File dbpath; + + @After + public void cleanup() throws Exception { + if (db != null) { + db.close(); + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @Before + public void setup() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + db = new LevelDB(dbpath); + } + + @Test + public void testReopenAndVersionCheckDb() throws Exception { + db.close(); + db = null; + assertTrue(dbpath.exists()); + + db = new LevelDB(dbpath); + assertEquals(LevelDB.STORE_VERSION, + db.serializer.deserializeLong(db.db().get(LevelDB.STORE_VERSION_KEY))); + db.db().put(LevelDB.STORE_VERSION_KEY, db.serializer.serialize(LevelDB.STORE_VERSION + 1)); + db.close(); + db = null; + + try { + db = new LevelDB(dbpath); + fail("Should have failed version check."); + } catch (UnsupportedStoreVersionException e) { + // Expected. 
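testReopenAndVersionCheckDb relies on the store recording its format version under STORE_VERSION_KEY when the database is created, and refusing to open data written with a different version. The constructor logic is not part of this excerpt, but the check being exercised amounts to something like this hypothetical sketch:

```java
import java.io.IOException;

public class StoreVersionCheckSketch {

  /** Throws if the on-disk version does not match what this code expects. */
  static void checkVersion(Long storedVersion, long expectedVersion) throws IOException {
    if (storedVersion == null) {
      // Fresh database: the caller should write the current version under STORE_VERSION_KEY.
      return;
    }
    if (storedVersion != expectedVersion) {
      // Same exception type the test expects when reopening mismatched data.
      throw new UnsupportedStoreVersionException();
    }
  }
}
```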
+ } + } + + @Test + public void testObjectWriteReadDelete() throws Exception { + CustomType1 t = new CustomType1(); + t.key = "key"; + t.id = "id"; + t.name = "name"; + t.child = "child"; + + try { + db.read(CustomType1.class, t.key); + fail("Expected exception for non-existant object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + + db.write(t); + assertEquals(t, db.read(t.getClass(), t.key)); + assertEquals(1L, db.count(t.getClass())); + + db.delete(t.getClass(), t.key); + try { + db.read(t.getClass(), t.key); + fail("Expected exception for deleted object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + + // Look into the actual DB and make sure that all the keys related to the type have been + // removed. + assertEquals(0, countKeys(t.getClass())); + } + + @Test + public void testMultipleObjectWriteReadDelete() throws Exception { + CustomType1 t1 = new CustomType1(); + t1.key = "key1"; + t1.id = "id"; + t1.name = "name1"; + t1.child = "child1"; + + CustomType1 t2 = new CustomType1(); + t2.key = "key2"; + t2.id = "id"; + t2.name = "name2"; + t2.child = "child2"; + + db.write(t1); + db.write(t2); + + assertEquals(t1, db.read(t1.getClass(), t1.key)); + assertEquals(t2, db.read(t2.getClass(), t2.key)); + assertEquals(2L, db.count(t1.getClass())); + + // There should be one "id" index entry with two values. + assertEquals(2, db.count(t1.getClass(), "id", t1.id)); + + // Delete the first entry; now there should be 3 remaining keys, since one of the "name" + // index entries should have been removed. + db.delete(t1.getClass(), t1.key); + + // Make sure there's a single entry in the "id" index now. + assertEquals(1, db.count(t2.getClass(), "id", t2.id)); + + // Delete the remaining entry, make sure all data is gone. + db.delete(t2.getClass(), t2.key); + assertEquals(0, countKeys(t2.getClass())); + } + + @Test + public void testMultipleTypesWriteReadDelete() throws Exception { + CustomType1 t1 = new CustomType1(); + t1.key = "1"; + t1.id = "id"; + t1.name = "name1"; + t1.child = "child1"; + + IntKeyType t2 = new IntKeyType(); + t2.key = 2; + t2.id = "2"; + t2.values = Arrays.asList("value1", "value2"); + + ArrayKeyIndexType t3 = new ArrayKeyIndexType(); + t3.key = new int[] { 42, 84 }; + t3.id = new String[] { "id1", "id2" }; + + db.write(t1); + db.write(t2); + db.write(t3); + + assertEquals(t1, db.read(t1.getClass(), t1.key)); + assertEquals(t2, db.read(t2.getClass(), t2.key)); + assertEquals(t3, db.read(t3.getClass(), t3.key)); + + // There should be one "id" index with a single entry for each type. + assertEquals(1, db.count(t1.getClass(), "id", t1.id)); + assertEquals(1, db.count(t2.getClass(), "id", t2.id)); + assertEquals(1, db.count(t3.getClass(), "id", t3.id)); + + // Delete the first entry; this should not affect the entries for the second type. + db.delete(t1.getClass(), t1.key); + assertEquals(0, countKeys(t1.getClass())); + assertEquals(1, db.count(t2.getClass(), "id", t2.id)); + assertEquals(1, db.count(t3.getClass(), "id", t3.id)); + + // Delete the remaining entries, make sure all data is gone. 
+ db.delete(t2.getClass(), t2.key); + assertEquals(0, countKeys(t2.getClass())); + + db.delete(t3.getClass(), t3.key); + assertEquals(0, countKeys(t3.getClass())); + } + + @Test + public void testMetadata() throws Exception { + assertNull(db.getMetadata(CustomType1.class)); + + CustomType1 t = new CustomType1(); + t.id = "id"; + t.name = "name"; + t.child = "child"; + + db.setMetadata(t); + assertEquals(t, db.getMetadata(CustomType1.class)); + + db.setMetadata(null); + assertNull(db.getMetadata(CustomType1.class)); + } + + @Test + public void testUpdate() throws Exception { + CustomType1 t = new CustomType1(); + t.key = "key"; + t.id = "id"; + t.name = "name"; + t.child = "child"; + + db.write(t); + + t.name = "anotherName"; + + db.write(t); + + assertEquals(1, db.count(t.getClass())); + assertEquals(1, db.count(t.getClass(), "name", "anotherName")); + assertEquals(0, db.count(t.getClass(), "name", "name")); + } + + @Test + public void testSkip() throws Exception { + for (int i = 0; i < 10; i++) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + i; + t.child = "child" + i; + + db.write(t); + } + + KVStoreIterator it = db.view(CustomType1.class).closeableIterator(); + assertTrue(it.hasNext()); + assertTrue(it.skip(5)); + assertEquals("key5", it.next().key); + assertTrue(it.skip(3)); + assertEquals("key9", it.next().key); + assertFalse(it.hasNext()); + } + + private int countKeys(Class type) throws Exception { + byte[] prefix = db.getTypeInfo(type).keyPrefix(); + int count = 0; + + DBIterator it = db.db().iterator(); + it.seek(prefix); + + while (it.hasNext()) { + byte[] key = it.next().getKey(); + if (LevelDBIterator.startsWith(key, prefix)) { + count++; + } + } + + return count; + } + + public static class IntKeyType { + + @KVIndex + public int key; + + @KVIndex("id") + public String id; + + public List values; + + @Override + public boolean equals(Object o) { + if (o instanceof IntKeyType) { + IntKeyType other = (IntKeyType) o; + return key == other.key && id.equals(other.id) && values.equals(other.values); + } + return false; + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBTypeInfoSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBTypeInfoSuite.java new file mode 100644 index 000000000000..38db3bedaef6 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBTypeInfoSuite.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
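countKeys above scans raw LevelDB keys and counts those sharing the type's key prefix, which is how the tests verify that deleting an entity also removes all of its index entries. LevelDBIterator.startsWith itself is not part of this excerpt, but such a check is just a byte-wise prefix comparison, roughly:

```java
public class PrefixCheckSketch {

  /** True if `key` begins with the bytes of `prefix` (illustrative, not the actual implementation). */
  static boolean startsWith(byte[] key, byte[] prefix) {
    if (key.length < prefix.length) {
      return false;
    }
    for (int i = 0; i < prefix.length; i++) {
      if (key[i] != prefix[i]) {
        return false;
      }
    }
    return true;
  }
}
```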
+ */ + +package org.apache.spark.util.kvstore; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class LevelDBTypeInfoSuite { + + @Test + public void testIndexAnnotation() throws Exception { + KVTypeInfo ti = new KVTypeInfo(CustomType1.class); + assertEquals(5, ti.indices().count()); + + CustomType1 t1 = new CustomType1(); + t1.key = "key"; + t1.id = "id"; + t1.name = "name"; + t1.num = 42; + t1.child = "child"; + + assertEquals(t1.key, ti.getIndexValue(KVIndex.NATURAL_INDEX_NAME, t1)); + assertEquals(t1.id, ti.getIndexValue("id", t1)); + assertEquals(t1.name, ti.getIndexValue("name", t1)); + assertEquals(t1.num, ti.getIndexValue("int", t1)); + assertEquals(t1.child, ti.getIndexValue("child", t1)); + } + + @Test(expected = IllegalArgumentException.class) + public void testNoNaturalIndex() throws Exception { + newTypeInfo(NoNaturalIndex.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testNoNaturalIndex2() throws Exception { + newTypeInfo(NoNaturalIndex2.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testDuplicateIndex() throws Exception { + newTypeInfo(DuplicateIndex.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyIndexName() throws Exception { + newTypeInfo(EmptyIndexName.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalIndexName() throws Exception { + newTypeInfo(IllegalIndexName.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalIndexMethod() throws Exception { + newTypeInfo(IllegalIndexMethod.class); + } + + @Test + public void testKeyClashes() throws Exception { + LevelDBTypeInfo ti = newTypeInfo(CustomType1.class); + + CustomType1 t1 = new CustomType1(); + t1.key = "key1"; + t1.name = "a"; + + CustomType1 t2 = new CustomType1(); + t2.key = "key2"; + t2.name = "aa"; + + CustomType1 t3 = new CustomType1(); + t3.key = "key3"; + t3.name = "aaa"; + + // Make sure entries with conflicting names are sorted correctly. 
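testIndexAnnotation checks that KVTypeInfo discovers every @KVIndex-annotated member of a type and can return the corresponding index values. KVTypeInfo's implementation is not included in this excerpt, but the field-scanning part of such discovery boils down to standard annotation reflection, roughly as sketched below (assumes @KVIndex is retained at runtime and visible to this code, and ignores annotated getter methods and validation):

```java
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;

public class IndexDiscoverySketch {

  /** Collects the @KVIndex-annotated fields of a type, keyed by index name. */
  static Map<String, Field> discoverIndices(Class<?> type) {
    Map<String, Field> indices = new HashMap<>();
    for (Field f : type.getDeclaredFields()) {
      KVIndex idx = f.getAnnotation(KVIndex.class);
      if (idx != null) {
        // idx.value() carries the index name; a bare @KVIndex denotes the natural index.
        indices.put(idx.value(), f);
      }
    }
    return indices;
  }
}
```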
+ assertBefore(ti.index("name").entityKey(null, t1), ti.index("name").entityKey(null, t2)); + assertBefore(ti.index("name").entityKey(null, t1), ti.index("name").entityKey(null, t3)); + assertBefore(ti.index("name").entityKey(null, t2), ti.index("name").entityKey(null, t3)); + } + + @Test + public void testNumEncoding() throws Exception { + LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); + + assertEquals("+=00000001", new String(idx.toKey(1), UTF_8)); + assertEquals("+=00000010", new String(idx.toKey(16), UTF_8)); + assertEquals("+=7fffffff", new String(idx.toKey(Integer.MAX_VALUE), UTF_8)); + + assertBefore(idx.toKey(1), idx.toKey(2)); + assertBefore(idx.toKey(-1), idx.toKey(2)); + assertBefore(idx.toKey(-11), idx.toKey(2)); + assertBefore(idx.toKey(-11), idx.toKey(-1)); + assertBefore(idx.toKey(1), idx.toKey(11)); + assertBefore(idx.toKey(Integer.MIN_VALUE), idx.toKey(Integer.MAX_VALUE)); + + assertBefore(idx.toKey(1L), idx.toKey(2L)); + assertBefore(idx.toKey(-1L), idx.toKey(2L)); + assertBefore(idx.toKey(Long.MIN_VALUE), idx.toKey(Long.MAX_VALUE)); + + assertBefore(idx.toKey((short) 1), idx.toKey((short) 2)); + assertBefore(idx.toKey((short) -1), idx.toKey((short) 2)); + assertBefore(idx.toKey(Short.MIN_VALUE), idx.toKey(Short.MAX_VALUE)); + + assertBefore(idx.toKey((byte) 1), idx.toKey((byte) 2)); + assertBefore(idx.toKey((byte) -1), idx.toKey((byte) 2)); + assertBefore(idx.toKey(Byte.MIN_VALUE), idx.toKey(Byte.MAX_VALUE)); + + byte prefix = LevelDBTypeInfo.ENTRY_PREFIX; + assertSame(new byte[] { prefix, LevelDBTypeInfo.FALSE }, idx.toKey(false)); + assertSame(new byte[] { prefix, LevelDBTypeInfo.TRUE }, idx.toKey(true)); + } + + @Test + public void testArrayIndices() throws Exception { + LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); + + assertBefore(idx.toKey(new String[] { "str1" }), idx.toKey(new String[] { "str2" })); + assertBefore(idx.toKey(new String[] { "str1", "str2" }), + idx.toKey(new String[] { "str1", "str3" })); + + assertBefore(idx.toKey(new int[] { 1 }), idx.toKey(new int[] { 2 })); + assertBefore(idx.toKey(new int[] { 1, 2 }), idx.toKey(new int[] { 1, 3 })); + } + + private LevelDBTypeInfo newTypeInfo(Class type) throws Exception { + return new LevelDBTypeInfo(null, type, type.getName().getBytes(UTF_8)); + } + + private void assertBefore(byte[] key1, byte[] key2) { + assertBefore(new String(key1, UTF_8), new String(key2, UTF_8)); + } + + private void assertBefore(String str1, String str2) { + assertTrue(String.format("%s < %s failed", str1, str2), str1.compareTo(str2) < 0); + } + + private void assertSame(byte[] key1, byte[] key2) { + assertEquals(new String(key1, UTF_8), new String(key2, UTF_8)); + } + + public static class NoNaturalIndex { + + public String id; + + } + + public static class NoNaturalIndex2 { + + @KVIndex("id") + public String id; + + } + + public static class DuplicateIndex { + + @KVIndex + public String key; + + @KVIndex("id") + public String id; + + @KVIndex("id") + public String id2; + + } + + public static class EmptyIndexName { + + @KVIndex("") + public String id; + + } + + public static class IllegalIndexName { + + @KVIndex("__invalid") + public String id; + + } + + public static class IllegalIndexMethod { + + @KVIndex("id") + public String id(boolean illegalParam) { + return null; + } + + } + +} diff --git a/dev/change-version-to-2.10.sh b/common/kvstore/src/test/resources/log4j.properties old mode 100755 new mode 100644 similarity index 63% rename from 
dev/change-version-to-2.10.sh rename to common/kvstore/src/test/resources/log4j.properties index 0962d34c52f2..e8da774f7ca9 --- a/dev/change-version-to-2.10.sh +++ b/common/kvstore/src/test/resources/log4j.properties @@ -1,5 +1,3 @@ -#!/usr/bin/env bash - # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -17,7 +15,13 @@ # limitations under the License. # -# This script exists for backwards compability. Use change-scala-version.sh instead. -echo "This script is deprecated. Please instead run: change-scala-version.sh 2.10" +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n -$(dirname $0)/change-scala-version.sh 2.10 +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 066970f24205..18cbdadd224a 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -61,6 +61,11 @@ jackson-annotations + + io.dropwizard.metrics + metrics-core + + org.slf4j @@ -90,7 +95,8 @@ org.apache.spark spark-tags_${scala.binary.version} - + test + - -XDignore.symbol.file - - - - - org.apache.maven.plugins - maven-compiler-plugin - 3.6.1 - - - - -XDignore.symbol.file - - - - - + + + net.alchim31.maven + scala-maven-plugin + + + + -XDignore.symbol.file + + + + diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 045fec33a282..fd1906d2e5ae 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -365,7 +365,7 @@ private void writeObject(ObjectOutputStream out) throws IOException { this.writeTo(out); } - private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + private void readObject(ObjectInputStream in) throws IOException { this.readFrom0(in); } } diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 680d0413b161..a3772a262008 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -93,31 +93,17 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - - - net.alchim31.maven - scala-maven-plugin - 3.2.2 - - - - -XDignore.symbol.file - - - - - org.apache.maven.plugins - maven-compiler-plugin - 3.6.1 - - - - -XDignore.symbol.file - - - - - + + + net.alchim31.maven + scala-maven-plugin + + + + -XDignore.symbol.file + + + + diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 1321b8318115..aca6fca00c48 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -48,7 +48,8 @@ public final class Platform { boolean _unaligned; String arch = System.getProperty("os.arch", ""); if (arch.equals("ppc64le") || arch.equals("ppc64")) { - // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but ppc64 and ppc64le support it + 
// Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but + // ppc64 and ppc64le support it _unaligned = true; } else { try { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 1a3cdff63826..2cd39bd60c2a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -39,7 +39,7 @@ public final class LongArray { private final long length; public LongArray(MemoryBlock memory) { - assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; + assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size >= Integer.MAX_VALUE elements"; this.memory = memory; this.baseObj = memory.getBaseObject(); this.baseOffset = memory.getBaseOffset(); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 3ced2094f5e6..7ced13d35723 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -44,7 +44,7 @@ public static long getPrefix(byte[] bytes) { final int minLen = Math.min(bytes.length, 8); long p = 0; for (int i = 0; i < minLen; ++i) { - p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i)) + p |= ((long) Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i) & 0xff) << (56 - 8 * i); } return p; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 5437e998c085..b0d0c44823e6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -64,7 +64,8 @@ public final class UTF8String implements Comparable, Externalizable, 5, 5, 5, 5, 6, 6}; - private static boolean isLittleEndian = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; + private static final boolean IS_LITTLE_ENDIAN = + ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); @@ -220,7 +221,7 @@ public long getPrefix() { // After getting the data, we use a mask to mask out data that is not part of the string. long p; long mask = 0; - if (isLittleEndian) { + if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { p = Platform.getLong(base, offset); } else if (numBytes > 4) { @@ -497,16 +498,30 @@ private UTF8String copyUTF8String(int start, int end) { public UTF8String trim() { int s = 0; - int e = this.numBytes - 1; // skip all of the space (0x20) in the left side while (s < this.numBytes && getByte(s) == 0x20) s++; - // skip all of the space (0x20) in the right side - while (e >= 0 && getByte(e) == 0x20) e--; - if (s > e) { + if (s == this.numBytes) { // empty string return EMPTY_UTF8; + } + // skip all of the space (0x20) in the right side + int e = this.numBytes - 1; + while (e > s && getByte(e) == 0x20) e--; + return copyUTF8String(s, e); + } + + /** + * Based on the given trim string, trim this string starting from both ends + * This method searches for each character in the source string, removes the character if it is + * found in the trim string, stops at the first not found. 
It calls the trimLeft first, then + * trimRight. It returns a new string in which both ends trim characters have been removed. + * @param trimString the trim character string + */ + public UTF8String trim(UTF8String trimString) { + if (trimString != null) { + return trimLeft(trimString).trimRight(trimString); } else { - return copyUTF8String(s, e); + return null; } } @@ -522,6 +537,42 @@ public UTF8String trimLeft() { } } + /** + * Based on the given trim string, trim this string starting from left end + * This method searches each character in the source string starting from the left end, removes + * the character if it is in the trim string, stops at the first character which is not in the + * trim string, returns the new string. + * @param trimString the trim character string + */ + public UTF8String trimLeft(UTF8String trimString) { + if (trimString == null) return null; + // the searching byte position in the source string + int srchIdx = 0; + // the first beginning byte position of a non-matching character + int trimIdx = 0; + + while (srchIdx < numBytes) { + UTF8String searchChar = copyUTF8String( + srchIdx, srchIdx + numBytesForFirstByte(this.getByte(srchIdx)) - 1); + int searchCharBytes = searchChar.numBytes; + // try to find the matching for the searchChar in the trimString set + if (trimString.find(searchChar, 0) >= 0) { + trimIdx += searchCharBytes; + } else { + // no matching, exit the search + break; + } + srchIdx += searchCharBytes; + } + + if (trimIdx >= numBytes) { + // empty string + return EMPTY_UTF8; + } else { + return copyUTF8String(trimIdx, numBytes - 1); + } + } + public UTF8String trimRight() { int e = numBytes - 1; // skip all of the space (0x20) in the right side @@ -535,6 +586,53 @@ public UTF8String trimRight() { } } + /** + * Based on the given trim string, trim this string starting from right end + * This method searches each character in the source string starting from the right end, + * removes the character if it is in the trim string, stops at the first character which is not + * in the trim string, returns the new string. + * @param trimString the trim character string + */ + public UTF8String trimRight(UTF8String trimString) { + if (trimString == null) return null; + int charIdx = 0; + // number of characters from the source string + int numChars = 0; + // array of character length for the source string + int[] stringCharLen = new int[numBytes]; + // array of the first byte position for each character in the source string + int[] stringCharPos = new int[numBytes]; + // build the position and length array + while (charIdx < numBytes) { + stringCharPos[numChars] = charIdx; + stringCharLen[numChars] = numBytesForFirstByte(getByte(charIdx)); + charIdx += stringCharLen[numChars]; + numChars ++; + } + + // index trimEnd points to the first no matching byte position from the right side of + // the source string. 
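The `trim(UTF8String)`, `trimLeft(UTF8String)` and `trimRight(UTF8String)` overloads added here (the `trimRight` body continues just below) strip any character contained in the given trim string. A usage sketch mirroring the accompanying `UTF8StringSuite` cases, assuming spark-unsafe on the classpath:

```java
import org.apache.spark.unsafe.types.UTF8String;

public class TrimStringExample {
  public static void main(String[] args) {
    UTF8String s = UTF8String.fromString("ooh e ooo");
    UTF8String set = UTF8String.fromString("o ");  // characters to strip
    System.out.println(s.trim(set));               // h e
    System.out.println(s.trimLeft(set));           // h e ooo
    System.out.println(s.trimRight(set));          // ooh e
  }
}
```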
+ int trimEnd = numBytes - 1; + while (numChars > 0) { + UTF8String searchChar = copyUTF8String( + stringCharPos[numChars - 1], + stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); + if (trimString.find(searchChar, 0) >= 0) { + trimEnd -= stringCharLen[numChars - 1]; + } else { + break; + } + numChars --; + } + + if (trimEnd < 0) { + // empty string + return EMPTY_UTF8; + } else { + return copyUTF8String(0, trimEnd); + } + } + public UTF8String reverse() { byte[] result = new byte[this.numBytes]; @@ -835,6 +933,15 @@ public UTF8String[] split(UTF8String pattern, int limit) { return res; } + public UTF8String replace(UTF8String search, UTF8String replace) { + if (EMPTY_UTF8.equals(search)) { + return this; + } + String replaced = toString().replace( + search.toString(), replace.toString()); + return fromString(replaced); + } + // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes public UTF8String translate(Map dict) { String srcStr = this.toString(); @@ -854,8 +961,8 @@ public UTF8String translate(Map dict) { * Wrapper over `long` to allow result of parsing long from string to be accessed via reference. * This is done solely for better performance and is not expected to be used by end users. */ - public static class LongWrapper { - public long value = 0; + public static class LongWrapper implements Serializable { + public transient long value = 0; } /** @@ -865,8 +972,8 @@ public static class LongWrapper { * {@link LongWrapper} could have been used here but using `int` directly save the extra cost of * conversion from `long` to `int` */ - public static class IntWrapper { - public int value = 0; + public static class IntWrapper implements Serializable { + public transient int value = 0; } /** @@ -1079,13 +1186,32 @@ public UTF8String clone() { return fromBytes(getBytes()); } + public UTF8String copy() { + byte[] bytes = new byte[numBytes]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + return fromBytes(bytes); + } + @Override public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); - // TODO: compare 8 bytes as unsigned long - for (int i = 0; i < len; i ++) { + int wordMax = (len / 8) * 8; + long roffset = other.offset; + Object rbase = other.base; + for (int i = 0; i < wordMax; i += 8) { + long left = getLong(base, offset + i); + long right = getLong(rbase, roffset + i); + if (left != right) { + if (IS_LITTLE_ENDIAN) { + return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); + } else { + return Long.compareUnsigned(left, right); + } + } + } + for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. 
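The `compareTo` rewrite in this hunk compares eight bytes at a time as an unsigned long before falling back to byte-wise comparison for the tail (which continues below); `Long.reverseBytes` is needed on little-endian hosts so the first byte of the string ends up most significant. A standalone sketch of the same idea over plain byte arrays, where reading big-endian plays the role of the byte swap:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class LexCompareSketch {
  // Lexicographic, unsigned comparison: 8-byte words first, then the tail.
  static int compare(byte[] a, byte[] b) {
    int len = Math.min(a.length, b.length);
    int wordMax = (len / 8) * 8;
    for (int i = 0; i < wordMax; i += 8) {
      long left = ByteBuffer.wrap(a, i, 8).order(ByteOrder.BIG_ENDIAN).getLong();
      long right = ByteBuffer.wrap(b, i, 8).order(ByteOrder.BIG_ENDIAN).getLong();
      if (left != right) {
        return Long.compareUnsigned(left, right);
      }
    }
    for (int i = wordMax; i < len; i++) {
      int res = (a[i] & 0xFF) - (b[i] & 0xFF);  // bytes compared as unsigned ints
      if (res != 0) {
        return res;
      }
    }
    return a.length - b.length;
  }

  public static void main(String[] args) {
    System.out.println(compare("apple".getBytes(), "apricot".getBytes()) < 0);  // true
  }
}
```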
- int res = (getByte(i) & 0xFF) - (other.getByte(i) & 0xFF); + int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); if (res != 0) { return res; } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index a77ba826fce2..4ae49d82efa2 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -73,5 +73,6 @@ public void memoryDebugFillEnabledInTest() { Assert.assertEquals( Platform.getByte(offheap.getBaseObject(), offheap.getBaseOffset()), MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + MemoryAllocator.UNSAFE.free(offheap); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index c376371abdf9..9b303fa5bc6c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -222,10 +222,13 @@ public void substring() { @Test public void trims() { + assertEquals(fromString("1"), fromString("1").trim()); + assertEquals(fromString("hello"), fromString(" hello ").trim()); assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.trim()); assertEquals(EMPTY_UTF8, fromString(" ").trim()); assertEquals(EMPTY_UTF8, fromString(" ").trimLeft()); assertEquals(EMPTY_UTF8, fromString(" ").trimRight()); @@ -730,4 +733,62 @@ public void testToLong() throws IOException { assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper)); } } + + @Test + public void trimBothWithTrimString() { + assertEquals(fromString("hello"), fromString(" hello ").trim(fromString(" "))); + assertEquals(fromString("o"), fromString(" hello ").trim(fromString(" hle"))); + assertEquals(fromString("h e"), fromString("ooh e ooo").trim(fromString("o "))); + assertEquals(fromString(""), fromString("ooo...oooo").trim(fromString("o."))); + assertEquals(fromString("b"), fromString("%^b[]@").trim(fromString("][@^%"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trim(fromString(" "))); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数"), fromString("a数b").trim(fromString("ab"))); + assertEquals(fromString(""), fromString("a").trim(fromString("a数b"))); + assertEquals(fromString(""), fromString("数数 数数数").trim(fromString("数 "))); + assertEquals(fromString("据砖头"), fromString("数]数[数据砖头#数数").trim(fromString("[数]#"))); + assertEquals(fromString("据砖头数数 "), fromString("数数数据砖头数数 ").trim(fromString("数"))); + } + + @Test + public void trimLeftWithTrimString() { + assertEquals(fromString(" hello "), fromString(" hello ").trimLeft(fromString(""))); + assertEquals(fromString(""), fromString("a").trimLeft(fromString("a"))); + assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a"))); + assertEquals(fromString("ba"), fromString("ba").trimLeft(fromString("a"))); + assertEquals(fromString(""), fromString("aaaaaaa").trimLeft(fromString("a"))); + assertEquals(fromString("trim"), fromString("oabtrim").trimLeft(fromString("bao"))); + assertEquals(fromString("rim "), fromString("ooootrim ").trimLeft(fromString("otm"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trimLeft(fromString(" 
"))); + + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft(fromString(" "))); + assertEquals(fromString("数"), fromString("数").trimLeft(fromString("a"))); + assertEquals(fromString("a"), fromString("a").trimLeft(fromString("数"))); + assertEquals(fromString("砖头数数"), fromString("数数数据砖头数数").trimLeft(fromString("据数"))); + assertEquals(fromString("据砖头数数"), fromString(" 数数数据砖头数数").trimLeft(fromString("数 "))); + assertEquals(fromString("据砖头数数"), fromString("aa数数数据砖头数数").trimLeft(fromString("a数砖"))); + assertEquals(fromString("$S,.$BR"), fromString(",,,,%$S,.$BR").trimLeft(fromString("%,"))); + } + + @Test + public void trimRightWithTrimString() { + assertEquals(fromString(" hello "), fromString(" hello ").trimRight(fromString(""))); + assertEquals(fromString(""), fromString("a").trimRight(fromString("a"))); + assertEquals(fromString("cc"), fromString("ccbaaaa").trimRight(fromString("ba"))); + assertEquals(fromString(""), fromString("aabbbbaaa").trimRight(fromString("ab"))); + assertEquals(fromString(" he"), fromString(" hello ").trimRight(fromString(" ol"))); + assertEquals(fromString("oohell"), + fromString("oohellooo../*&").trimRight(fromString("./,&%*o"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trimRight(fromString(" "))); + + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight(fromString(" "))); + assertEquals(fromString("数数砖头"), fromString("数数砖头数aa数").trimRight(fromString("a数"))); + assertEquals(fromString(""), fromString("数数数据砖ab").trimRight(fromString("数据砖ab"))); + assertEquals(fromString("头"), fromString("头a???/").trimRight(fromString("数?/*&^%a"))); + assertEquals(fromString("头"), fromString("头数b数数 [").trimRight(fromString(" []数b"))); + } } diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 55cb094b4af4..2ecb4f1464a4 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -15,6 +15,6 @@ # limitations under the License. 
# -spark.mesos.executor.docker.image: +spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index aeb76c9b2f6e..4c008a13607c 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -118,6 +118,14 @@ # prefix EMPTY STRING Prefix to prepend to every metric's name # protocol tcp Protocol ("tcp" or "udp") to use +# org.apache.spark.metrics.sink.StatsdSink +# Name: Default: Description: +# host 127.0.0.1 Hostname or IP of StatsD server +# port 8125 Port of StatsD server +# period 10 Poll period +# unit seconds Units of poll period +# prefix EMPTY STRING Prefix to prepend to metric name + ## Examples # Enable JmxSink for all instances by class name #*.sink.jmx.class=org.apache.spark.metrics.sink.JmxSink @@ -125,6 +133,10 @@ # Enable ConsoleSink for all instances by class name #*.sink.console.class=org.apache.spark.metrics.sink.ConsoleSink +# Enable StatsdSink for all instances by class name +#*.sink.statsd.class=org.apache.spark.metrics.sink.StatsdSink +#*.sink.statsd.prefix=spark + # Polling period for the ConsoleSink #*.sink.console.period=10 # Unit of the polling period for the ConsoleSink diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 94bd2c477a35..f8c895f5303b 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -34,7 +34,7 @@ # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files -# - SPARK_EXECUTOR_INSTANCES, Number of executors to start (Default: 2) +# - YARN_CONF_DIR, to point Spark towards YARN configuration files when you use YARN # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) @@ -52,6 +52,7 @@ # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") # - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") +# - SPARK_DAEMON_CLASSPATH, to set the classpath for all daemons # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers # Generic options for the daemons used in the standalone deploy mode @@ -61,3 +62,7 @@ # - SPARK_IDENT_STRING A string representing this instance of spark. (Default: $USER) # - SPARK_NICENESS The scheduling priority for daemons. (Default: 0) # - SPARK_NO_DAEMONIZE Run the proposed command in the foreground. It will not output a PID file. +# Options for native BLAS, like Intel MKL, OpenBLAS, and so on. +# You might get better performance to enable these options if using native BLAS (see SPARK-21305). 
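The `StatsdSink` entries documented in the metrics template above normally live in `metrics.properties`; Spark also accepts the same keys through `SparkConf` under the `spark.metrics.conf.` prefix, which the sketch below relies on (prefix passthrough assumed; the values are the template defaults):

```java
import org.apache.spark.SparkConf;

public class StatsdMetricsConf {
  public static void main(String[] args) {
    // Programmatic equivalent of the commented template lines above.
    SparkConf conf = new SparkConf()
        .set("spark.metrics.conf.*.sink.statsd.class",
             "org.apache.spark.metrics.sink.StatsdSink")
        .set("spark.metrics.conf.*.sink.statsd.host", "127.0.0.1")
        .set("spark.metrics.conf.*.sink.statsd.port", "8125")
        .set("spark.metrics.conf.*.sink.statsd.prefix", "spark");
    System.out.println(conf.toDebugString());
  }
}
```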
+# - MKL_NUM_THREADS=1 Disable multi-threading of Intel MKL +# - OPENBLAS_NUM_THREADS=1 Disable multi-threading of OpenBLAS diff --git a/core/pom.xml b/core/pom.xml index 7f245b5b6384..09669149d812 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -67,6 +67,11 @@ spark-launcher_${scala.binary.version} ${project.version} + + org.apache.spark + spark-kvstore_${scala.binary.version} + ${project.version} + org.apache.spark spark-network-common_${scala.binary.version} @@ -190,8 +195,8 @@ snappy-java - net.jpountz.lz4 - lz4 + org.lz4 + lz4-java org.roaringbitmap @@ -335,7 +340,7 @@ net.sf.py4j py4j - 0.10.4 + 0.10.6 org.apache.spark @@ -357,6 +362,34 @@ org.apache.commons commons-crypto + + + + ${hive.group} + hive-exec + provided + + + ${hive.group} + hive-metastore + provided + + + org.apache.thrift + libthrift + provided + + + org.apache.thrift + libfb303 + provided + + target/scala-${scala.binary.version}/classes @@ -401,6 +434,7 @@ + copy-dependencies package @@ -454,6 +488,7 @@ org.codehaus.mojo exec-maven-plugin + 1.6.0 sparkr-pkg diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 140c52fd12f9..3583856d8899 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -139,6 +139,11 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } + @Override + public void onSpeculativeTaskSubmitted(SparkListenerSpeculativeTaskSubmitted speculativeTask) { + onEvent(speculativeTask); + } + @Override public void onOtherEvent(SparkListenerEvent event) { onEvent(event); diff --git a/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java b/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java deleted file mode 100644 index 9d6f06ed2888..000000000000 --- a/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java +++ /dev/null @@ -1,260 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.io; - -import java.io.EOFException; -import java.io.FilterInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.util.zip.Checksum; - -import net.jpountz.lz4.LZ4Exception; -import net.jpountz.lz4.LZ4Factory; -import net.jpountz.lz4.LZ4FastDecompressor; -import net.jpountz.util.SafeUtils; -import net.jpountz.xxhash.XXHashFactory; - -/** - * {@link InputStream} implementation to decode data written with - * {@link net.jpountz.lz4.LZ4BlockOutputStream}. This class is not thread-safe and does not - * support {@link #mark(int)}/{@link #reset()}. 
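The forked `LZ4BlockInputStream` removed in this file deletion is superseded by the upstream `org.lz4:lz4-java` artifact pulled in by the `core/pom.xml` change above. A round-trip sketch using the upstream block streams (assuming lz4-java on the classpath):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import net.jpountz.lz4.LZ4BlockInputStream;
import net.jpountz.lz4.LZ4BlockOutputStream;

public class Lz4RoundTrip {
  public static void main(String[] args) throws Exception {
    byte[] input = "hello lz4-java".getBytes(StandardCharsets.UTF_8);

    // Compress with the upstream block format.
    ByteArrayOutputStream compressed = new ByteArrayOutputStream();
    try (LZ4BlockOutputStream out = new LZ4BlockOutputStream(compressed)) {
      out.write(input);
    }

    // Decompress with the upstream LZ4BlockInputStream, which replaces the
    // forked copy that this patch deletes.
    try (LZ4BlockInputStream in =
             new LZ4BlockInputStream(new ByteArrayInputStream(compressed.toByteArray()))) {
      ByteArrayOutputStream restored = new ByteArrayOutputStream();
      byte[] buf = new byte[8192];
      for (int n; (n = in.read(buf)) != -1; ) {
        restored.write(buf, 0, n);
      }
      System.out.println(restored.toString("UTF-8"));  // hello lz4-java
    }
  }
}
```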
- * @see net.jpountz.lz4.LZ4BlockOutputStream - * - * This is based on net.jpountz.lz4.LZ4BlockInputStream - * - * changes: https://github.com/davies/lz4-java/commit/cc1fa940ac57cc66a0b937300f805d37e2bf8411 - * - * TODO: merge this into upstream - */ -public final class LZ4BlockInputStream extends FilterInputStream { - - // Copied from net.jpountz.lz4.LZ4BlockOutputStream - static final byte[] MAGIC = new byte[] { 'L', 'Z', '4', 'B', 'l', 'o', 'c', 'k' }; - static final int MAGIC_LENGTH = MAGIC.length; - - static final int HEADER_LENGTH = - MAGIC_LENGTH // magic bytes - + 1 // token - + 4 // compressed length - + 4 // decompressed length - + 4; // checksum - - static final int COMPRESSION_LEVEL_BASE = 10; - - static final int COMPRESSION_METHOD_RAW = 0x10; - static final int COMPRESSION_METHOD_LZ4 = 0x20; - - static final int DEFAULT_SEED = 0x9747b28c; - - private final LZ4FastDecompressor decompressor; - private final Checksum checksum; - private byte[] buffer; - private byte[] compressedBuffer; - private int originalLen; - private int o; - private boolean finished; - - /** - * Create a new {@link InputStream}. - * - * @param in the {@link InputStream} to poll - * @param decompressor the {@link LZ4FastDecompressor decompressor} instance to - * use - * @param checksum the {@link Checksum} instance to use, must be - * equivalent to the instance which has been used to - * write the stream - */ - public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Checksum checksum) { - super(in); - this.decompressor = decompressor; - this.checksum = checksum; - this.buffer = new byte[0]; - this.compressedBuffer = new byte[HEADER_LENGTH]; - o = originalLen = 0; - finished = false; - } - - /** - * Create a new instance using {@link net.jpountz.xxhash.XXHash32} for checksuming. - * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum) - * @see net.jpountz.xxhash.StreamingXXHash32#asChecksum() - */ - public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor) { - this(in, decompressor, - XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum()); - } - - /** - * Create a new instance which uses the fastest {@link LZ4FastDecompressor} available. 
- * @see LZ4Factory#fastestInstance() - * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor) - */ - public LZ4BlockInputStream(InputStream in) { - this(in, LZ4Factory.fastestInstance().fastDecompressor()); - } - - @Override - public int available() throws IOException { - refill(); - return originalLen - o; - } - - @Override - public int read() throws IOException { - refill(); - if (finished) { - return -1; - } - return buffer[o++] & 0xFF; - } - - @Override - public int read(byte[] b, int off, int len) throws IOException { - SafeUtils.checkRange(b, off, len); - refill(); - if (finished) { - return -1; - } - len = Math.min(len, originalLen - o); - System.arraycopy(buffer, o, b, off, len); - o += len; - return len; - } - - @Override - public int read(byte[] b) throws IOException { - return read(b, 0, b.length); - } - - @Override - public long skip(long n) throws IOException { - refill(); - if (finished) { - return -1; - } - final int skipped = (int) Math.min(n, originalLen - o); - o += skipped; - return skipped; - } - - private void refill() throws IOException { - if (finished || o < originalLen) { - return; - } - try { - readFully(compressedBuffer, HEADER_LENGTH); - } catch (EOFException e) { - finished = true; - return; - } - for (int i = 0; i < MAGIC_LENGTH; ++i) { - if (compressedBuffer[i] != MAGIC[i]) { - throw new IOException("Stream is corrupted"); - } - } - final int token = compressedBuffer[MAGIC_LENGTH] & 0xFF; - final int compressionMethod = token & 0xF0; - final int compressionLevel = COMPRESSION_LEVEL_BASE + (token & 0x0F); - if (compressionMethod != COMPRESSION_METHOD_RAW && compressionMethod != COMPRESSION_METHOD_LZ4) - { - throw new IOException("Stream is corrupted"); - } - final int compressedLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 1); - originalLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 5); - final int check = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 9); - assert HEADER_LENGTH == MAGIC_LENGTH + 13; - if (originalLen > 1 << compressionLevel - || originalLen < 0 - || compressedLen < 0 - || (originalLen == 0 && compressedLen != 0) - || (originalLen != 0 && compressedLen == 0) - || (compressionMethod == COMPRESSION_METHOD_RAW && originalLen != compressedLen)) { - throw new IOException("Stream is corrupted"); - } - if (originalLen == 0 && compressedLen == 0) { - if (check != 0) { - throw new IOException("Stream is corrupted"); - } - refill(); - return; - } - if (buffer.length < originalLen) { - buffer = new byte[Math.max(originalLen, buffer.length * 3 / 2)]; - } - switch (compressionMethod) { - case COMPRESSION_METHOD_RAW: - readFully(buffer, originalLen); - break; - case COMPRESSION_METHOD_LZ4: - if (compressedBuffer.length < originalLen) { - compressedBuffer = new byte[Math.max(compressedLen, compressedBuffer.length * 3 / 2)]; - } - readFully(compressedBuffer, compressedLen); - try { - final int compressedLen2 = - decompressor.decompress(compressedBuffer, 0, buffer, 0, originalLen); - if (compressedLen != compressedLen2) { - throw new IOException("Stream is corrupted"); - } - } catch (LZ4Exception e) { - throw new IOException("Stream is corrupted", e); - } - break; - default: - throw new AssertionError(); - } - checksum.reset(); - checksum.update(buffer, 0, originalLen); - if ((int) checksum.getValue() != check) { - throw new IOException("Stream is corrupted"); - } - o = 0; - } - - private void readFully(byte[] b, int len) throws IOException { - int read = 0; - while (read < len) { - final int r = in.read(b, read, len - 
read); - if (r < 0) { - throw new EOFException("Stream ended prematurely"); - } - read += r; - } - assert len == read; - } - - @Override - public boolean markSupported() { - return false; - } - - @SuppressWarnings("sync-override") - @Override - public void mark(int readlimit) { - // unsupported - } - - @SuppressWarnings("sync-override") - @Override - public void reset() throws IOException { - throw new IOException("mark/reset not supported"); - } - - @Override - public String toString() { - return getClass().getSimpleName() + "(in=" + in - + ", decompressor=" + decompressor + ", checksum=" + checksum + ")"; - } - -} diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index ea5f1a9abf69..f6d1288cb263 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,10 +130,8 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } - //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } - //checkstyle.on: NoFinalizer } diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java new file mode 100644 index 000000000000..5b45d268ace8 --- /dev/null +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -0,0 +1,411 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.io; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.apache.spark.util.ThreadUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.GuardedBy; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.InterruptedIOException; +import java.nio.ByteBuffer; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +/** + * {@link InputStream} implementation which asynchronously reads ahead from the underlying input + * stream when specified amount of data has been read from the current buffer. It does it by + * maintaining two buffers - active buffer and read ahead buffer. Active buffer contains data + * which should be returned when a read() call is issued. The read ahead buffer is used to + * asynchronously read from the underlying input stream and once the current active buffer is + * exhausted, we flip the two buffers so that we can start reading from the read ahead buffer + * without being blocked in disk I/O. 
+ */ +public class ReadAheadInputStream extends InputStream { + + private static final Logger logger = LoggerFactory.getLogger(ReadAheadInputStream.class); + + private ReentrantLock stateChangeLock = new ReentrantLock(); + + @GuardedBy("stateChangeLock") + private ByteBuffer activeBuffer; + + @GuardedBy("stateChangeLock") + private ByteBuffer readAheadBuffer; + + @GuardedBy("stateChangeLock") + private boolean endOfStream; + + @GuardedBy("stateChangeLock") + // true if async read is in progress + private boolean readInProgress; + + @GuardedBy("stateChangeLock") + // true if read is aborted due to an exception in reading from underlying input stream. + private boolean readAborted; + + @GuardedBy("stateChangeLock") + private Throwable readException; + + @GuardedBy("stateChangeLock") + // whether the close method is called. + private boolean isClosed; + + @GuardedBy("stateChangeLock") + // true when the close method will close the underlying input stream. This is valid only if + // `isClosed` is true. + private boolean isUnderlyingInputStreamBeingClosed; + + @GuardedBy("stateChangeLock") + // whether there is a read ahead task running, + private boolean isReading; + + // If the remaining data size in the current buffer is below this threshold, + // we issue an async read from the underlying input stream. + private final int readAheadThresholdInBytes; + + private final InputStream underlyingInputStream; + + private final ExecutorService executorService = + ThreadUtils.newDaemonSingleThreadExecutor("read-ahead"); + + private final Condition asyncReadComplete = stateChangeLock.newCondition(); + + private static final ThreadLocal oneByte = ThreadLocal.withInitial(() -> new byte[1]); + + /** + * Creates a ReadAheadInputStream with the specified buffer size and read-ahead + * threshold + * + * @param inputStream The underlying input stream. + * @param bufferSizeInBytes The buffer size. + * @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead + * threshold, an async read is triggered. + */ + public ReadAheadInputStream( + InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) { + Preconditions.checkArgument(bufferSizeInBytes > 0, + "bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes); + Preconditions.checkArgument(readAheadThresholdInBytes > 0 && + readAheadThresholdInBytes < bufferSizeInBytes, + "readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, " + + "but the value is " + readAheadThresholdInBytes); + activeBuffer = ByteBuffer.allocate(bufferSizeInBytes); + readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes); + this.readAheadThresholdInBytes = readAheadThresholdInBytes; + this.underlyingInputStream = inputStream; + activeBuffer.flip(); + readAheadBuffer.flip(); + } + + private boolean isEndOfStream() { + return (!activeBuffer.hasRemaining() && !readAheadBuffer.hasRemaining() && endOfStream); + } + + private void checkReadException() throws IOException { + if (readAborted) { + Throwables.propagateIfPossible(readException, IOException.class); + throw new IOException(readException); + } + } + + /** Read data from underlyingInputStream to readAheadBuffer asynchronously. 
*/ + private void readAsync() throws IOException { + stateChangeLock.lock(); + final byte[] arr = readAheadBuffer.array(); + try { + if (endOfStream || readInProgress) { + return; + } + checkReadException(); + readAheadBuffer.position(0); + readAheadBuffer.flip(); + readInProgress = true; + } finally { + stateChangeLock.unlock(); + } + executorService.execute(new Runnable() { + + @Override + public void run() { + stateChangeLock.lock(); + try { + if (isClosed) { + readInProgress = false; + return; + } + // Flip this so that the close method will not close the underlying input stream when we + // are reading. + isReading = true; + } finally { + stateChangeLock.unlock(); + } + + // Please note that it is safe to release the lock and read into the read ahead buffer + // because either of following two conditions will hold - 1. The active buffer has + // data available to read so the reader will not read from the read ahead buffer. + // 2. This is the first time read is called or the active buffer is exhausted, + // in that case the reader waits for this async read to complete. + // So there is no race condition in both the situations. + int read = 0; + Throwable exception = null; + try { + while (true) { + read = underlyingInputStream.read(arr); + if (0 != read) break; + } + } catch (Throwable ex) { + exception = ex; + if (ex instanceof Error) { + // `readException` may not be reported to the user. Rethrow Error to make sure at least + // The user can see Error in UncaughtExceptionHandler. + throw (Error) ex; + } + } finally { + stateChangeLock.lock(); + if (read < 0 || (exception instanceof EOFException)) { + endOfStream = true; + } else if (exception != null) { + readAborted = true; + readException = exception; + } else { + readAheadBuffer.limit(read); + } + readInProgress = false; + signalAsyncReadComplete(); + stateChangeLock.unlock(); + closeUnderlyingInputStreamIfNecessary(); + } + } + }); + } + + private void closeUnderlyingInputStreamIfNecessary() { + boolean needToCloseUnderlyingInputStream = false; + stateChangeLock.lock(); + try { + isReading = false; + if (isClosed && !isUnderlyingInputStreamBeingClosed) { + // close method cannot close underlyingInputStream because we were reading. + needToCloseUnderlyingInputStream = true; + } + } finally { + stateChangeLock.unlock(); + } + if (needToCloseUnderlyingInputStream) { + try { + underlyingInputStream.close(); + } catch (IOException e) { + logger.warn(e.getMessage(), e); + } + } + } + + private void signalAsyncReadComplete() { + stateChangeLock.lock(); + try { + asyncReadComplete.signalAll(); + } finally { + stateChangeLock.unlock(); + } + } + + private void waitForAsyncReadComplete() throws IOException { + stateChangeLock.lock(); + try { + while (readInProgress) { + asyncReadComplete.await(); + } + } catch (InterruptedException e) { + InterruptedIOException iio = new InterruptedIOException(e.getMessage()); + iio.initCause(e); + throw iio; + } finally { + stateChangeLock.unlock(); + } + checkReadException(); + } + + @Override + public int read() throws IOException { + byte[] oneByteArray = oneByte.get(); + return read(oneByteArray, 0, 1) == -1 ? 
-1 : oneByteArray[0] & 0xFF; + } + + @Override + public int read(byte[] b, int offset, int len) throws IOException { + if (offset < 0 || len < 0 || len > b.length - offset) { + throw new IndexOutOfBoundsException(); + } + if (len == 0) { + return 0; + } + stateChangeLock.lock(); + try { + return readInternal(b, offset, len); + } finally { + stateChangeLock.unlock(); + } + } + + /** + * flip the active and read ahead buffer + */ + private void swapBuffers() { + ByteBuffer temp = activeBuffer; + activeBuffer = readAheadBuffer; + readAheadBuffer = temp; + } + + /** + * Internal read function which should be called only from read() api. The assumption is that + * the stateChangeLock is already acquired in the caller before calling this function. + */ + private int readInternal(byte[] b, int offset, int len) throws IOException { + assert (stateChangeLock.isLocked()); + if (!activeBuffer.hasRemaining()) { + waitForAsyncReadComplete(); + if (readAheadBuffer.hasRemaining()) { + swapBuffers(); + } else { + // The first read or activeBuffer is skipped. + readAsync(); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return -1; + } + swapBuffers(); + } + } else { + checkReadException(); + } + len = Math.min(len, activeBuffer.remaining()); + activeBuffer.get(b, offset, len); + + if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) { + readAsync(); + } + return len; + } + + @Override + public int available() throws IOException { + stateChangeLock.lock(); + // Make sure we have no integer overflow. + try { + return (int) Math.min((long) Integer.MAX_VALUE, + (long) activeBuffer.remaining() + readAheadBuffer.remaining()); + } finally { + stateChangeLock.unlock(); + } + } + + @Override + public long skip(long n) throws IOException { + if (n <= 0L) { + return 0L; + } + stateChangeLock.lock(); + long skipped; + try { + skipped = skipInternal(n); + } finally { + stateChangeLock.unlock(); + } + return skipped; + } + + /** + * Internal skip function which should be called only from skip() api. The assumption is that + * the stateChangeLock is already acquired in the caller before calling this function. 
+ */ + private long skipInternal(long n) throws IOException { + assert (stateChangeLock.isLocked()); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return 0; + } + if (available() >= n) { + // we can skip from the internal buffers + int toSkip = (int) n; + if (toSkip <= activeBuffer.remaining()) { + // Only skipping from active buffer is sufficient + activeBuffer.position(toSkip + activeBuffer.position()); + if (activeBuffer.remaining() <= readAheadThresholdInBytes + && !readAheadBuffer.hasRemaining()) { + readAsync(); + } + return n; + } + // We need to skip from both active buffer and read ahead buffer + toSkip -= activeBuffer.remaining(); + activeBuffer.position(0); + activeBuffer.flip(); + readAheadBuffer.position(toSkip + readAheadBuffer.position()); + swapBuffers(); + readAsync(); + return n; + } else { + int skippedBytes = available(); + long toSkip = n - skippedBytes; + activeBuffer.position(0); + activeBuffer.flip(); + readAheadBuffer.position(0); + readAheadBuffer.flip(); + long skippedFromInputStream = underlyingInputStream.skip(toSkip); + readAsync(); + return skippedBytes + skippedFromInputStream; + } + } + + @Override + public void close() throws IOException { + boolean isSafeToCloseUnderlyingInputStream = false; + stateChangeLock.lock(); + try { + if (isClosed) { + return; + } + isClosed = true; + if (!isReading) { + // Nobody is reading, so we can close the underlying input stream in this method. + isSafeToCloseUnderlyingInputStream = true; + // Flip this to make sure the read ahead task will not close the underlying input stream. + isUnderlyingInputStreamBeingClosed = true; + } + } finally { + stateChangeLock.unlock(); + } + + try { + executorService.shutdownNow(); + executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS); + } catch (InterruptedException e) { + InterruptedIOException iio = new InterruptedIOException(e.getMessage()); + iio.initCause(e); + throw iio; + } finally { + if (isSafeToCloseUnderlyingInputStream) { + underlyingInputStream.close(); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 48cf4b9455e4..0efae16e9838 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -89,13 +89,7 @@ public LongArray allocateArray(long size) { long required = size * 8L; MemoryBlock page = taskMemoryManager.allocatePage(required, this); if (page == null || page.size() < required) { - long got = 0; - if (page != null) { - got = page.size(); - taskMemoryManager.freePage(page, this); - } - taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + throwOom(page, required); } used += required; return new LongArray(page); @@ -111,20 +105,12 @@ public void freeArray(LongArray array) { /** * Allocate a memory block with at least `required` bytes. * - * Throws IOException if there is not enough memory. 
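Stepping back to the `ReadAheadInputStream` added above: callers hand it the stream to wrap plus a buffer size and a read-ahead threshold, matching the constructor in the new file. A minimal usage sketch (the path and sizes are illustrative only):

```java
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import org.apache.spark.io.ReadAheadInputStream;

public class ReadAheadExample {
  public static void main(String[] args) throws IOException {
    // Wrap a plain file stream with a 1 MB buffer; the async read-ahead fires
    // once fewer than 64 KB remain in the active buffer.
    try (InputStream in = new ReadAheadInputStream(
        new FileInputStream("/tmp/spill.data"), 1024 * 1024, 64 * 1024)) {
      byte[] chunk = new byte[8192];
      long total = 0;
      for (int n; (n = in.read(chunk)) != -1; ) {
        total += n;
      }
      System.out.println("read " + total + " bytes");
    }
  }
}
```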
- * * @throws OutOfMemoryError */ protected MemoryBlock allocatePage(long required) { MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); if (page == null || page.size() < required) { - long got = 0; - if (page != null) { - got = page.size(); - taskMemoryManager.freePage(page, this); - } - taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + throwOom(page, required); } used += page.size(); return page; @@ -154,4 +140,14 @@ public void freeMemory(long size) { taskMemoryManager.releaseExecutionMemory(size, this); used -= size; } + + private void throwOom(final MemoryBlock page, final long required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } + taskMemoryManager.showMemoryUsage(); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + } } diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index aa0b37323132..44b60c1e4e8c 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -19,6 +19,7 @@ import javax.annotation.concurrent.GuardedBy; import java.io.IOException; +import java.nio.channels.ClosedByInterruptException; import java.util.Arrays; import java.util.ArrayList; import java.util.BitSet; @@ -52,8 +53,8 @@ * retrieve the base object. *

* This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the - * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is - * approximately 35 terabytes of memory. + * maximum size of a long[] array, allowing us to address 8192 * (2^31 - 1) * 8 bytes, which is + * approximately 140 terabytes of memory. */ public class TaskMemoryManager { @@ -73,7 +74,8 @@ public class TaskMemoryManager { * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's * maximum page size is limited by the maximum amount of data that can be stored in a long[] - * array, which is (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes. + * array, which is (2^31 - 1) * 8 bytes (or about 17 gigabytes). Therefore, we cap this at 17 + * gigabytes. */ public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L; @@ -155,7 +157,8 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { long key = c.getUsed(); - List list = sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); + List list = + sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); list.add(c); } } @@ -183,6 +186,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { break; } } + } catch (ClosedByInterruptException e) { + // This called by user to kill a task (e.g: speculative task). + logger.error("error while calling spill() on " + c, e); + throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + c, e); throw new OutOfMemoryError("error while calling spill() on " + c + " : " @@ -200,6 +207,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { Utils.bytesToString(released), consumer); got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); } + } catch (ClosedByInterruptException e) { + // This called by user to kill a task (e.g: speculative task). 
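The corrected addressing comment and the `MAXIMUM_PAGE_SIZE_BYTES` cap above both follow from the largest possible `long[]`; a quick check of the arithmetic:

```java
public class PageAddressing {
  public static void main(String[] args) {
    // 13-bit page numbers give 8192 pages; an on-heap page is backed by a
    // long[], so its payload tops out at (2^31 - 1) elements of 8 bytes each.
    long maxPageBytes = ((1L << 31) - 1) * 8L;
    long addressable = 8192L * maxPageBytes;
    System.out.println(maxPageBytes);  // 17179869176      (~17 GB per page)
    System.out.println(addressable);   // 140737488289792  (~140 TB in total)
  }
}
```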
+ logger.error("error while calling spill() on " + consumer, e); + throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + consumer, e); throw new OutOfMemoryError("error while calling spill() on " + consumer + " : " diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 323a5d3c5283..a9b5236ab817 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -18,9 +18,9 @@ package org.apache.spark.shuffle.sort; import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; import java.io.IOException; +import java.nio.channels.FileChannel; +import static java.nio.file.StandardOpenOption.*; import javax.annotation.Nullable; import scala.None$; @@ -75,7 +75,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private static final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); private final int fileBufferSize; - private final boolean transferToEnabled; private final int numPartitions; private final BlockManager blockManager; private final Partitioner partitioner; @@ -107,7 +106,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { SparkConf conf) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; - this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); this.blockManager = blockManager; final ShuffleDependency dep = handle.dependency(); this.mapId = mapId; @@ -188,17 +186,21 @@ private long[] writePartitionedFile(File outputFile) throws IOException { return lengths; } - final FileOutputStream out = new FileOutputStream(outputFile, true); + // This file needs to opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. + final FileChannel out = FileChannel.open(outputFile.toPath(), WRITE, APPEND, CREATE); final long writeStartTime = System.nanoTime(); boolean threwException = true; try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); if (file.exists()) { - final FileInputStream in = new FileInputStream(file); + final FileChannel in = FileChannel.open(file.toPath(), READ); boolean copyThrewException = true; try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + long size = in.size(); + Utils.copyFileStreamNIO(in, out, 0, size); + lengths[i] = size; copyThrewException = false; } finally { Closeables.close(in, copyThrewException); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index c33d1e33f030..b4f46306f282 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -43,6 +43,7 @@ import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; +import org.apache.spark.internal.config.package$; /** * An external sorter that is specialized for sort-based shuffle. 
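`Utils.copyFileStreamNIO`, used in the `BypassMergeSortShuffleWriter` change above, centralizes the position-tracking `transferTo` loop that the writer used to hand-roll. A standalone sketch of that pattern (not Spark's exact helper):

```java
import java.io.IOException;
import java.nio.channels.FileChannel;

public final class ChannelCopy {
  // Copy `length` bytes starting at `position` from `in` to `out`, looping
  // because transferTo may move fewer bytes than requested per call.
  static void copyRange(FileChannel in, FileChannel out, long position, long length)
      throws IOException {
    long pos = position;
    long remaining = length;
    while (remaining > 0) {
      long transferred = in.transferTo(pos, remaining, out);
      pos += transferred;
      remaining -= transferred;
    }
  }
}
```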
@@ -82,6 +83,9 @@ final class ShuffleExternalSorter extends MemoryConsumer { /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; + /** The buffer size to use when writing the sorted records to an on-disk file */ + private final int diskWriteBufferSize; + /** * Memory pages that hold the records being sorted. The pages in this list are freed when * spilling, although in principle we could recycle these pages across spills (on the other hand, @@ -116,13 +120,16 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.taskContext = taskContext; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.fileBufferSizeBytes = + (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", 1024 * 1024 * 1024); this.writeMetrics = writeMetrics; this.inMemSorter = new ShuffleInMemorySorter( this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true)); this.peakMemoryUsedBytes = getMemoryUsage(); + this.diskWriteBufferSize = + (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); } /** @@ -133,7 +140,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { * bytes written should be counted towards shuffle spill metrics rather than * shuffle write metrics. */ - private void writeSortedFile(boolean isLastFile) throws IOException { + private void writeSortedFile(boolean isLastFile) { final ShuffleWriteMetrics writeMetricsToUse; @@ -155,7 +162,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // be an API to directly transfer bytes from managed memory to the disk writer, we buffer // data through a byte array. This array does not need to be large enough to hold a single // record; - final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + final byte[] writeBuffer = new byte[diskWriteBufferSize]; // Because this output will be read during shuffle, its compression codec must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use @@ -195,7 +202,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { - final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); + final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); Platform.copyMemory( recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); writer.write(writeBuffer, 0, toTransfer); @@ -318,7 +325,7 @@ public void cleanupResources() { * array and grows the array if additional space is required. If the required space cannot be * obtained, then the in-memory data will be spilled to disk. 
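The sorter now reads its buffer sizes through typed config entries, while the user-facing keys visible in this hunk stay the same. A sketch of tuning them on a `SparkConf` (the values here are examples only):

```java
import org.apache.spark.SparkConf;

public class ShuffleBufferTuning {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
        // Buffer used by DiskBlockObjectWriter when writing spill files;
        // a bare number is read as KiB, as the old getSizeAsKb call did.
        .set("spark.shuffle.file.buffer", "64k")
        // Force a spill once this many records accumulate in memory.
        .set("spark.shuffle.spill.numElementsForceSpillThreshold", "5000000");
    System.out.println(conf.toDebugString());
  }
}
```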
*/ - private void growPointerArrayIfNecessary() throws IOException { + private void growPointerArrayIfNecessary() { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); @@ -399,19 +406,14 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p * @throws IOException */ public SpillInfo[] closeAndGetSpills() throws IOException { - try { - if (inMemSorter != null) { - // Do not count the final file towards the spill count. - writeSortedFile(true); - freeMemory(); - inMemSorter.free(); - inMemSorter = null; - } - return spills.toArray(new SpillInfo[spills.size()]); - } catch (IOException e) { - cleanupResources(); - throw e; + if (inMemSorter != null) { + // Do not count the final file towards the spill count. + writeSortedFile(true); + freeMemory(); + inMemSorter.free(); + inMemSorter = null; } + return spills.toArray(new SpillInfo[spills.size()]); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 8a1771848dee..e9c2a69c47cb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -20,6 +20,7 @@ import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; +import static java.nio.file.StandardOpenOption.*; import java.util.Iterator; import scala.Option; @@ -40,6 +41,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.commons.io.output.CloseShieldOutputStream; import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; @@ -54,6 +56,7 @@ import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; +import org.apache.spark.internal.config.package$; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -64,6 +67,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @VisibleForTesting static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; + static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; private final BlockManager blockManager; private final IndexShuffleBlockResolver shuffleBlockResolver; @@ -77,6 +81,8 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SparkConf sparkConf; private final boolean transferToEnabled; private final int initialSortBufferSize; + private final int inputBufferSizeInBytes; + private final int outputBufferSizeInBytes; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -98,6 +104,18 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream */ private boolean stopping = false; + private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream { + + CloseAndFlushShieldOutputStream(OutputStream outputStream) { + super(outputStream); + } + + @Override + public void flush() { + // do nothing + } + } + public UnsafeShuffleWriter( BlockManager blockManager, IndexShuffleBlockResolver shuffleBlockResolver, @@ -127,6 +145,10 @@ public UnsafeShuffleWriter( this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); this.initialSortBufferSize = 
sparkConf.getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE); + this.inputBufferSizeInBytes = + (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + this.outputBufferSizeInBytes = + (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; open(); } @@ -186,7 +208,7 @@ public void write(scala.collection.Iterator> records) throws IOEx } } - private void open() throws IOException { + private void open() { assert (sorter == null); sorter = new ShuffleExternalSorter( memoryManager, @@ -196,7 +218,7 @@ private void open() throws IOException { partitioner.numPartitions(), sparkConf, writeMetrics); - serBuffer = new MyByteArrayOutputStream(1024 * 1024); + serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); serOutputStream = serializer.serializeStream(serBuffer); } @@ -269,7 +291,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); try { if (spills.length == 0) { - new FileOutputStream(outputFile).close(); // Create an empty file + java.nio.file.Files.newOutputStream(outputFile.toPath()).close(); // Create an empty file return new long[partitioner.numPartitions()]; } else if (spills.length == 1) { // Here, we don't need to perform any metrics updates because the bytes written to this @@ -321,11 +343,15 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti } /** - * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, - * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in - * cases where the IO compression codec does not support concatenation of compressed data, when - * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in - * order to work around kernel bugs. + * Merges spill files using Java FileStreams. This code path is typically slower than + * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], + * File)}, and it's mostly used in cases where the IO compression codec does not support + * concatenation of compressed data, when encryption is enabled, or when users have + * explicitly disabled use of {@code transferTo} in order to work around kernel bugs. + * This code path might also be faster in cases where individual partition size in a spill + * is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small + * disk ios which is inefficient. In those case, Using large buffers for input and output + * files helps reducing the number of disk ios, making the file merging faster. * * @param spills the spills to merge. * @param outputFile the file to write the merged data to. @@ -339,23 +365,28 @@ private long[] mergeSpillsWithFileStream( assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; - final InputStream[] spillInputStreams = new FileInputStream[spills.length]; + final InputStream[] spillInputStreams = new InputStream[spills.length]; + final OutputStream bos = new BufferedOutputStream( + java.nio.file.Files.newOutputStream(outputFile.toPath()), + outputBufferSizeInBytes); // Use a counting output stream to avoid having to close the underlying file and ask // the file system for its size after each partition is written. 
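`CloseAndFlushShieldOutputStream` above extends commons-io's `CloseShieldOutputStream` so that neither `close()` nor `flush()` on the per-partition wrapper reaches the shared merged-file stream. A standalone sketch of the same shielding idea without the commons-io dependency:

```java
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;

// Simplified shield: swallow flush() and close() so wrapping compression or
// encryption streams can be closed per partition while the underlying merged
// output stream stays open until the very end.
final class FlushAndCloseShield extends FilterOutputStream {
  FlushAndCloseShield(OutputStream out) {
    super(out);
  }

  @Override
  public void write(byte[] b, int off, int len) throws IOException {
    out.write(b, off, len);  // delegate directly instead of byte-at-a-time
  }

  @Override
  public void flush() {
    // deliberately ignored
  }

  @Override
  public void close() {
    // deliberately ignored; the caller closes the underlying stream itself
  }
}
```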
- final CountingOutputStream mergedFileOutputStream = new CountingOutputStream( - new FileOutputStream(outputFile)); + final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { - spillInputStreams[i] = new FileInputStream(spills[i].file); + spillInputStreams[i] = new NioBufferedFileInputStream( + spills[i].file, + inputBufferSizeInBytes); } for (int partition = 0; partition < numPartitions; partition++) { final long initialFileLength = mergedFileOutputStream.getByteCount(); - // Shield the underlying output stream from close() calls, so that we can close the higher - // level streams to make sure all data is really flushed and internal state is cleaned. - OutputStream partitionOutput = new CloseShieldOutputStream( + // Shield the underlying output stream from close() and flush() calls, so that we can close + // the higher level streams to make sure all data is really flushed and internal state is + // cleaned. + OutputStream partitionOutput = new CloseAndFlushShieldOutputStream( new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { @@ -412,27 +443,24 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { - spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); + spillInputChannels[i] = FileChannel.open(spills[i].file.toPath(), READ); } // This file needs to opened in append mode in order to work around a Linux kernel bug that // affects transferTo; see SPARK-3948 for more details. - mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); + mergedFileOutputChannel = FileChannel.open(outputFile.toPath(), WRITE, CREATE, APPEND); long bytesWrittenToMergedFile = 0; for (int partition = 0; partition < numPartitions; partition++) { for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - long bytesToTransfer = partitionLengthInSpill; final FileChannel spillInputChannel = spillInputChannels[i]; final long writeStartTime = System.nanoTime(); - while (bytesToTransfer > 0) { - final long actualBytesTransferred = spillInputChannel.transferTo( - spillInputChannelPositions[i], - bytesToTransfer, - mergedFileOutputChannel); - spillInputChannelPositions[i] += actualBytesTransferred; - bytesToTransfer -= actualBytesTransferred; - } + Utils.copyFileStreamNIO( + spillInputChannel, + mergedFileOutputChannel, + spillInputChannelPositions[i], + partitionLengthInSpill); + spillInputChannelPositions[i] += partitionLengthInSpill; writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); bytesWrittenToMergedFile += partitionLengthInSpill; partitionLengths[partition] += partitionLengthInSpill; diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 4bef21b6b4e4..4fadfe36cd71 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -160,14 +160,10 @@ public final class BytesToBytesMap extends MemoryConsumer { private final boolean enablePerfMetrics; - private long timeSpentResizingNs = 0; - private long numProbes = 0; private long numKeyLookups = 0; - private long 
numHashCollisions = 0; - private long peakMemoryUsedBytes = 0L; private final int initialCapacity; @@ -262,6 +258,11 @@ private MapIterator(int numRecords, Location loc, boolean destructive) { this.destructive = destructive; if (destructive) { destructiveIterator = this; + // longArray will not be used anymore if destructive is true, release it now. + if (longArray != null) { + freeArray(longArray); + longArray = null; + } } } @@ -282,13 +283,7 @@ private void advanceToNextPage() { } else { currentPage = null; if (reader != null) { - // remove the spill file from disk - File file = spillWriters.removeFirst().getFile(); - if (file != null && file.exists()) { - if (!file.delete()) { - logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); - } - } + handleFailedDelete(); } try { Closeables.close(reader, /* swallowIOException = */ false); @@ -306,13 +301,7 @@ private void advanceToNextPage() { public boolean hasNext() { if (numRecords == 0) { if (reader != null) { - // remove the spill file from disk - File file = spillWriters.removeFirst().getFile(); - if (file != null && file.exists()) { - if (!file.delete()) { - logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); - } - } + handleFailedDelete(); } } return numRecords > 0; @@ -402,6 +391,14 @@ public long spill(long numBytes) throws IOException { public void remove() { throw new UnsupportedOperationException(); } + + private void handleFailedDelete() { + // remove the spill file from disk + File file = spillWriters.removeFirst().getFile(); + if (file != null && file.exists() && !file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } } /** @@ -489,10 +486,6 @@ public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location l ); if (areEqual) { return; - } else { - if (enablePerfMetrics) { - numHashCollisions++; - } } } } @@ -859,16 +852,6 @@ public long getPeakMemoryUsedBytes() { return peakMemoryUsedBytes; } - /** - * Returns the total amount of time spent resizing this map (in nanoseconds). - */ - public long getTimeSpentResizingNs() { - if (!enablePerfMetrics) { - throw new IllegalStateException(); - } - return timeSpentResizingNs; - } - /** * Returns the average number of probes per key lookup. 
*/ @@ -879,13 +862,6 @@ public double getAverageProbesPerLookup() { return (1.0 * numProbes) / numKeyLookups; } - public long getNumHashCollisions() { - if (!enablePerfMetrics) { - throw new IllegalStateException(); - } - return numHashCollisions; - } - @VisibleForTesting public int getNumDataPages() { return dataPages.size(); @@ -923,10 +899,6 @@ public void reset() { void growAndRehash() { assert(longArray != null); - long resizeStartTime = -1; - if (enablePerfMetrics) { - resizeStartTime = System.nanoTime(); - } // Store references to the old data structures to be used when we re-hash final LongArray oldLongArray = longArray; final int oldCapacity = (int) oldLongArray.size() / 2; @@ -951,9 +923,5 @@ void growAndRehash() { longArray.set(newPos * 2 + 1, hashcode); } freeArray(oldLongArray); - - if (enablePerfMetrics) { - timeSpentResizingNs += System.nanoTime() - resizeStartTime; - } } } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java index 20654e4eeaa0..b8c2294c7b7a 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -30,11 +30,17 @@ public interface HashMapGrowthStrategy { HashMapGrowthStrategy DOUBLING = new Doubling(); class Doubling implements HashMapGrowthStrategy { + + // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat + // smaller. Be conservative and lower the cap a little. + private static final int ARRAY_MAX = Integer.MAX_VALUE - 8; + @Override public int nextCapacity(int currentCapacity) { assert (currentCapacity > 0); + int doubleCapacity = currentCapacity * 2; // Guard against overflow - return (currentCapacity * 2 > 0) ? (currentCapacity * 2) : Integer.MAX_VALUE; + return (doubleCapacity > 0 && doubleCapacity <= ARRAY_MAX) ? doubleCapacity : ARRAY_MAX; } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index f312fa2b2ddd..39eda00dd7ef 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.LinkedList; import java.util.Queue; +import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; @@ -48,13 +49,20 @@ public final class UnsafeExternalSorter extends MemoryConsumer { @Nullable private final PrefixComparator prefixComparator; + + /** + * {@link RecordComparator} may probably keep the reference to the records they compared last + * time, so we should not keep a {@link RecordComparator} instance inside + * {@link UnsafeExternalSorter}, because {@link UnsafeExternalSorter} is referenced by + * {@link TaskContext} and thus can not be garbage collected until the end of the task. 
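Editor's aside on the `HashMapGrowthStrategy.Doubling` change earlier in this chunk: doubling an `int` capacity can overflow, and most JVMs cannot allocate arrays of exactly `Integer.MAX_VALUE` elements, so the result is clamped. A compact, runnable sketch of just that arithmetic:

```java
// Restates the capacity-doubling guard from the diff above, outside of Spark's classes.
final class DoublingGrowthSketch {
  // Conservative cap: actual JVM array-length limits are slightly below Integer.MAX_VALUE.
  private static final int ARRAY_MAX = Integer.MAX_VALUE - 8;

  static int nextCapacity(int currentCapacity) {
    assert currentCapacity > 0;
    int doubled = currentCapacity * 2;
    // A negative value means the multiplication overflowed; fall back to the cap.
    return (doubled > 0 && doubled <= ARRAY_MAX) ? doubled : ARRAY_MAX;
  }

  public static void main(String[] args) {
    System.out.println(nextCapacity(64));       // 128
    System.out.println(nextCapacity(1 << 30));  // clamped to ARRAY_MAX
  }
}
```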
+ */ @Nullable - private final RecordComparator recordComparator; + private final Supplier recordComparatorSupplier; + private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final SerializerManager serializerManager; private final TaskContext taskContext; - private ShuffleWriteMetrics writeMetrics; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -91,14 +99,14 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + Supplier recordComparatorSupplier, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, long numElementsForSpillThreshold, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, - serializerManager, taskContext, recordComparator, prefixComparator, initialSize, + serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, numElementsForSpillThreshold, pageSizeBytes, inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. @@ -111,14 +119,14 @@ public static UnsafeExternalSorter create( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + Supplier recordComparatorSupplier, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, long numElementsForSpillThreshold, boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, - taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, + taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes, numElementsForSpillThreshold, null, canUseRadixSort); } @@ -127,7 +135,7 @@ private UnsafeExternalSorter( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + Supplier recordComparatorSupplier, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, @@ -139,19 +147,24 @@ private UnsafeExternalSorter( this.blockManager = blockManager; this.serializerManager = serializerManager; this.taskContext = taskContext; - this.recordComparator = recordComparator; + this.recordComparatorSupplier = recordComparatorSupplier; this.prefixComparator = prefixComparator; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 this.fileBufferSizeBytes = 32 * 1024; - // The spill metrics are stored in a new ShuffleWriteMetrics, - // and then discarded (this fixes SPARK-16827). - // TODO: Instead, separate spill metrics should be stored and reported (tracked in SPARK-3577). 
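The switch above from a stored `RecordComparator` to a `Supplier` is about lifetime: a long-lived sorter should not pin records that a comparator may still reference. A generic sketch of the pattern, with illustrative names rather than Spark's own types:

```java
import java.util.Comparator;
import java.util.function.Supplier;

// Illustrative only: the owner keeps a factory rather than a comparator instance, so a
// comparator (which may hold references to the last records it compared) is created only
// when a sort actually runs and can be garbage collected afterwards.
final class LazyComparatorOwner<T> {
  private final Supplier<Comparator<T>> comparatorSupplier;

  LazyComparatorOwner(Supplier<Comparator<T>> comparatorSupplier) {
    this.comparatorSupplier = comparatorSupplier;
  }

  /** Build a fresh comparator only at the point where it is actually needed. */
  Comparator<T> newComparatorOrNull() {
    return comparatorSupplier == null ? null : comparatorSupplier.get();
  }
}
```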
- this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { + RecordComparator comparator = null; + if (recordComparatorSupplier != null) { + comparator = recordComparatorSupplier.get(); + } this.inMemSorter = new UnsafeInMemorySorter( - this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort); + this, + taskMemoryManager, + comparator, + prefixComparator, + initialSize, + canUseRadixSort); } else { this.inMemSorter = existingInMemorySorter; } @@ -199,21 +212,14 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { spillWriters.size(), spillWriters.size() > 1 ? " times" : " time"); + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); // We only write out contents of the inMemSorter if it is not empty. if (inMemSorter.numRecords() > 0) { final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, inMemSorter.numRecords()); spillWriters.add(spillWriter); - final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); - while (sortedRecords.hasNext()) { - sortedRecords.loadNext(); - final Object baseObject = sortedRecords.getBaseObject(); - final long baseOffset = sortedRecords.getBaseOffset(); - final int recordLength = sortedRecords.getRecordLength(); - spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); - } - spillWriter.close(); + spillIterator(inMemSorter.getSortedIterator(), spillWriter); } final long spillSize = freeMemory(); @@ -226,6 +232,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { // pages, we might not be able to get memory for the pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); totalSpillBytes += spillSize; return spillSize; } @@ -331,7 +338,7 @@ public void cleanupResources() { * array and grows the array if additional space is required. If the required space cannot be * obtained, then the in-memory data will be spilled to disk. */ - private void growPointerArrayIfNecessary() throws IOException { + private void growPointerArrayIfNecessary() { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); @@ -454,14 +461,14 @@ public void merge(UnsafeExternalSorter other) throws IOException { * after consuming this iterator. 
*/ public UnsafeSorterIterator getSortedIterator() throws IOException { - assert(recordComparator != null); + assert(recordComparatorSupplier != null); if (spillWriters.isEmpty()) { assert(inMemSorter != null); readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); return readingIterator; } else { - final UnsafeSorterSpillMerger spillMerger = - new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size()); + final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger( + recordComparatorSupplier.get(), prefixComparator, spillWriters.size()); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager)); } @@ -473,6 +480,18 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { } } + private static void spillIterator(UnsafeSorterIterator inMemIterator, + UnsafeSorterSpillWriter spillWriter) throws IOException { + while (inMemIterator.hasNext()) { + inMemIterator.loadNext(); + final Object baseObject = inMemIterator.getBaseObject(); + final long baseOffset = inMemIterator.getBaseOffset(); + final int recordLength = inMemIterator.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); + } + spillWriter.close(); + } + /** * An UnsafeSorterIterator that support spilling. */ @@ -488,6 +507,7 @@ class SpillableIterator extends UnsafeSorterIterator { this.numRecords = inMemIterator.getNumRecords(); } + @Override public int getNumRecords() { return numRecords; } @@ -502,17 +522,11 @@ public long spill() throws IOException { UnsafeInMemorySorter.SortedIterator inMemIterator = ((UnsafeInMemorySorter.SortedIterator) upstream).clone(); + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); // Iterate over the records that have not been returned and spill them. final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); - while (inMemIterator.hasNext()) { - inMemIterator.loadNext(); - final Object baseObject = inMemIterator.getBaseObject(); - final long baseOffset = inMemIterator.getBaseOffset(); - final int recordLength = inMemIterator.getRecordLength(); - spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); - } - spillWriter.close(); + spillIterator(inMemIterator, spillWriter); spillWriters.add(spillWriter); nextUpstream = spillWriter.getReader(serializerManager); @@ -540,6 +554,7 @@ public long spill() throws IOException { inMemSorter.free(); inMemSorter = null; taskContext.taskMetrics().incMemoryBytesSpilled(released); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); totalSpillBytes += released; return released; } @@ -590,29 +605,54 @@ public long getKeyPrefix() { } /** - * Returns a iterator, which will return the rows in the order as inserted. + * Returns an iterator starts from startIndex, which will return the rows in the order as + * inserted. * * It is the caller's responsibility to call `cleanupResources()` * after consuming this iterator. 
* * TODO: support forced spilling */ - public UnsafeSorterIterator getIterator() throws IOException { + public UnsafeSorterIterator getIterator(int startIndex) throws IOException { if (spillWriters.isEmpty()) { assert(inMemSorter != null); - return inMemSorter.getSortedIterator(); + UnsafeSorterIterator iter = inMemSorter.getSortedIterator(); + moveOver(iter, startIndex); + return iter; } else { LinkedList queue = new LinkedList<>(); + int i = 0; for (UnsafeSorterSpillWriter spillWriter : spillWriters) { - queue.add(spillWriter.getReader(serializerManager)); + if (i + spillWriter.recordsSpilled() > startIndex) { + UnsafeSorterIterator iter = spillWriter.getReader(serializerManager); + moveOver(iter, startIndex - i); + queue.add(iter); + } + i += spillWriter.recordsSpilled(); } if (inMemSorter != null) { - queue.add(inMemSorter.getSortedIterator()); + UnsafeSorterIterator iter = inMemSorter.getSortedIterator(); + moveOver(iter, startIndex - i); + queue.add(iter); } return new ChainedIterator(queue); } } + private void moveOver(UnsafeSorterIterator iter, int steps) + throws IOException { + if (steps > 0) { + for (int i = 0; i < steps; i++) { + if (iter.hasNext()) { + iter.loadNext(); + } else { + throw new ArrayIndexOutOfBoundsException("Failed to move the iterator " + steps + + " steps forward"); + } + } + } + } + /** * Chain multiple UnsafeSorterIterator together as single one. */ diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 9521ab86a12d..e2f48e5508af 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -17,20 +17,20 @@ package org.apache.spark.util.collection.unsafe.sort; -import java.io.*; - import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; - import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; import org.apache.spark.io.NioBufferedFileInputStream; +import org.apache.spark.io.ReadAheadInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; import org.apache.spark.unsafe.Platform; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.*; + /** * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). @@ -72,10 +72,22 @@ public UnsafeSorterSpillReader( bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES; } + final double readAheadFraction = + SparkEnv.get() == null ? 
0.5 : + SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); + + final boolean readAheadEnabled = SparkEnv.get() != null && + SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); + final InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes); try { - this.in = serializerManager.wrapStream(blockId, bs); + if (readAheadEnabled) { + this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs), + (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction)); + } else { + this.in = serializerManager.wrapStream(blockId, bs); + } this.din = new DataInputStream(this.in); numRecords = numRecordsRemaining = din.readInt(); } catch (IOException e) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 164b9d70b79d..9399024f0178 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -20,9 +20,10 @@ import java.io.File; import java.io.IOException; -import org.apache.spark.serializer.SerializerManager; import scala.Tuple2; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.storage.BlockId; @@ -30,6 +31,7 @@ import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.Platform; +import org.apache.spark.internal.config.package$; /** * Spills a list of sorted records to disk. Spill files have the following format: @@ -38,12 +40,16 @@ */ public final class UnsafeSorterSpillWriter { - static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + private final SparkConf conf = new SparkConf(); + + /** The buffer size to use when writing the sorted records to an on-disk file */ + private final int diskWriteBufferSize = + (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer // data through a byte array. - private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + private byte[] writeBuffer = new byte[diskWriteBufferSize]; private final File file; private final BlockId blockId; @@ -73,7 +79,7 @@ public UnsafeSorterSpillWriter( } // Based on DataOutputStream.writeLong. - private void writeLongToBuffer(long v, int offset) throws IOException { + private void writeLongToBuffer(long v, int offset) { writeBuffer[offset + 0] = (byte)(v >>> 56); writeBuffer[offset + 1] = (byte)(v >>> 48); writeBuffer[offset + 2] = (byte)(v >>> 40); @@ -85,7 +91,7 @@ private void writeLongToBuffer(long v, int offset) throws IOException { } // Based on DataOutputStream.writeInt. 
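The spill-reader change above gates an optional read-ahead wrapper on two settings: an enabled flag and a fraction of the buffer size. A rough, runnable sketch of that decision; the property keys are the ones quoted in the hunk, but `BufferedInputStream` is only a stand-in for Spark's `ReadAheadInputStream` and the rest is illustrative:

```java
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Properties;

final class SpillReadAheadSketch {
  static InputStream openSpill(String path, Properties conf) throws IOException {
    int bufferSizeBytes = 1024 * 1024;
    double readAheadFraction = Double.parseDouble(
        conf.getProperty("spark.unsafe.sorter.spill.read.ahead.fraction", "0.5"));
    boolean readAheadEnabled = Boolean.parseBoolean(
        conf.getProperty("spark.unsafe.sorter.spill.read.ahead.enabled", "true"));

    // Primary buffered stream over the spill file.
    InputStream in = new BufferedInputStream(new FileInputStream(path), bufferSizeBytes);
    if (readAheadEnabled) {
      // The diff wraps the decoded stream in a ReadAheadInputStream whose read-ahead buffer
      // is a fraction of the main buffer; a second BufferedInputStream stands in for it here.
      in = new BufferedInputStream(in, (int) (bufferSizeBytes * readAheadFraction));
    }
    return in;
  }
}
```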
- private void writeIntToBuffer(int v, int offset) throws IOException { + private void writeIntToBuffer(int v, int offset) { writeBuffer[offset + 0] = (byte)(v >>> 24); writeBuffer[offset + 1] = (byte)(v >>> 16); writeBuffer[offset + 2] = (byte)(v >>> 8); @@ -114,7 +120,7 @@ public void write( writeIntToBuffer(recordLength, 0); writeLongToBuffer(keyPrefix, 4); int dataRemaining = recordLength; - int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + int freeSpaceInWriteBuffer = diskWriteBufferSize - 4 - 8; // space used by prefix + len long recordReadPosition = baseOffset; while (dataRemaining > 0) { final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); @@ -122,15 +128,15 @@ public void write( baseObject, recordReadPosition, writeBuffer, - Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + Platform.BYTE_ARRAY_OFFSET + (diskWriteBufferSize - freeSpaceInWriteBuffer), toTransfer); - writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + writer.write(writeBuffer, 0, (diskWriteBufferSize - freeSpaceInWriteBuffer) + toTransfer); recordReadPosition += toTransfer; dataRemaining -= toTransfer; - freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + freeSpaceInWriteBuffer = diskWriteBufferSize; } - if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { - writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + if (freeSpaceInWriteBuffer < diskWriteBufferSize) { + writer.write(writeBuffer, 0, (diskWriteBufferSize - freeSpaceInWriteBuffer)); } writer.recordWritten(); } @@ -149,4 +155,8 @@ public File getFile() { public UnsafeSorterSpillReader getReader(SerializerManager serializerManager) throws IOException { return new UnsafeSorterSpillReader(serializerManager, file, blockId); } + + public int recordsSpilled() { + return numRecordsSpilled; + } } diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index cb9922d23c44..d430d8c5fb35 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -26,7 +26,6 @@ function getThreadDumpEnabled() { } function formatStatus(status, type) { - if (type !== 'display') return status; if (status) { return "Active" } else { @@ -417,7 +416,6 @@ $(document).ready(function () { }, {data: 'hostPort'}, {data: 'isActive', render: function (data, type, row) { - if (type !== 'display') return data; if (row.isBlacklisted) return "Blacklisted"; else return formatStatus (data, type); } @@ -492,24 +490,20 @@ $(document).ready(function () { {data: 'totalInputBytes', render: formatBytes}, {data: 'totalShuffleRead', render: formatBytes}, {data: 'totalShuffleWrite', render: formatBytes}, - {data: 'executorLogs', render: formatLogsCells}, + {name: 'executorLogsCol', data: 'executorLogs', render: formatLogsCells}, { + name: 'threadDumpCol', data: 'id', render: function (data, type) { return type === 'display' ? 
("Thread Dump" ) : data; } } ], - "columnDefs": [ - { - "targets": [ 16 ], - "visible": getThreadDumpEnabled() - } - ], "order": [[0, "asc"]] }; var dt = $(selector).DataTable(conf); - dt.column(15).visible(logsExist(response)); + dt.column('executorLogsCol:name').visible(logsExist(response)); + dt.column('threadDumpCol:name').visible(getThreadDumpEnabled()); $('#active-executors [data-toggle="tooltip"]').tooltip(); var sumSelector = "#summary-execs-table"; diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 6ba3b092dc65..18d921ab67be 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -20,47 +20,51 @@ - + App ID - + App Name - - + {{#hasMultipleAttempts}} + + Attempt ID + {{/hasMultipleAttempts}} - + Started + {{#showCompletedColumns}} - + Completed - + Duration + {{/showCompletedColumns}} - + Spark User - + Last Updated - + Event Log @@ -68,13 +72,17 @@ {{#applications}} - {{id}} - {{name}} + {{id}} + {{name}} {{#attempts}} - {{attemptId}} + {{#hasMultipleAttempts}} + {{attemptId}} + {{/hasMultipleAttempts}} {{startTime}} + {{#showCompletedColumns}} {{endTime}} - {{duration}} + {{duration}} + {{/showCompletedColumns}} {{sparkUser}} {{lastUpdated}} Download diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 1f89306403cd..aa7e86037255 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -48,6 +48,18 @@ function getParameterByName(name, searchString) { return results === null ? "" : decodeURIComponent(results[1].replace(/\+/g, " ")); } +function removeColumnByName(columns, columnName) { + return columns.filter(function(col) {return col.name != columnName}) +} + +function getColumnIndex(columns, columnName) { + for(var i = 0; i < columns.length; i++) { + if (columns[i].name == columnName) + return i; + } + return -1; +} + jQuery.extend( jQuery.fn.dataTableExt.oSort, { "title-numeric-pre": function ( a ) { var x = a.match(/title="*(-?[0-9\.]+)/)[1]; @@ -122,73 +134,69 @@ $(document).ready(function() { attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" + (attempt.hasOwnProperty("attemptId") ? 
attempt["attemptId"] + "/" : "") + "logs"; - + attempt["durationMillisec"] = attempt["duration"]; + attempt["duration"] = formatDuration(attempt["duration"]); var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; array.push(app_clone); } } + if(array.length < 20) { + $.fn.dataTable.defaults.paging = false; + } var data = { "uiroot": uiRoot, - "applications": array - } + "applications": array, + "hasMultipleAttempts": hasMultipleAttempts, + "showCompletedColumns": !requestedIncomplete, + } $.get("static/historypage-template.html", function(template) { - historySummary.append(Mustache.render($(template).filter("#history-summary-template").html(),data)); - var selector = "#history-summary-table"; + var sibling = historySummary.prev(); + historySummary.detach(); + var apps = $(Mustache.render($(template).filter("#history-summary-template").html(),data)); + var attemptIdColumnName = 'attemptId'; + var startedColumnName = 'started'; + var defaultSortColumn = completedColumnName = 'completed'; + var durationColumnName = 'duration'; var conf = { - "columns": [ - {name: 'first', type: "appid-numeric"}, - {name: 'second'}, - {name: 'third'}, - {name: 'fourth'}, - {name: 'fifth'}, - {name: 'sixth', type: "title-numeric"}, - {name: 'seventh'}, - {name: 'eighth'}, - {name: 'ninth'}, - ], - "columnDefs": [ - {"searchable": false, "targets": [5]} - ], - "autoWidth": false, - "order": [[ 4, "desc" ]] - }; - - var rowGroupConf = { - "rowsGroup": [ - 'first:name', - 'second:name' - ], + "columns": [ + {name: 'appId', type: "appid-numeric"}, + {name: 'appName'}, + {name: attemptIdColumnName}, + {name: startedColumnName}, + {name: completedColumnName}, + {name: durationColumnName, type: "title-numeric"}, + {name: 'user'}, + {name: 'lastUpdated'}, + {name: 'eventLog'}, + ], + "autoWidth": false, }; if (hasMultipleAttempts) { - jQuery.extend(conf, rowGroupConf); - var rowGroupCells = document.getElementsByClassName("rowGroupColumn"); - for (i = 0; i < rowGroupCells.length; i++) { - rowGroupCells[i].style='background-color: #ffffff'; - } - } - - if (!hasMultipleAttempts) { - var attemptIDCells = document.getElementsByClassName("attemptIDSpan"); - for (i = 0; i < attemptIDCells.length; i++) { - attemptIDCells[i].style.display='none'; - } - } - - var durationCells = document.getElementsByClassName("durationClass"); - for (i = 0; i < durationCells.length; i++) { - var timeInMilliseconds = parseInt(durationCells[i].title); - durationCells[i].innerHTML = formatDuration(timeInMilliseconds); + conf.rowsGroup = [ + 'appId:name', + 'appName:name' + ]; + } else { + conf.columns = removeColumnByName(conf.columns, attemptIdColumnName); } - if ($(selector.concat(" tr")).length < 20) { - $.extend(conf, {paging: false}); + var defaultSortColumn = completedColumnName; + if (requestedIncomplete) { + defaultSortColumn = startedColumnName; + conf.columns = removeColumnByName(conf.columns, completedColumnName); + conf.columns = removeColumnByName(conf.columns, durationColumnName); } - - $(selector).DataTable(conf); - $('#hisotry-summary [data-toggle="tooltip"]').tooltip(); + conf.order = [[ getColumnIndex(conf.columns, defaultSortColumn), "desc" ]]; + conf.columnDefs = [ + {"searchable": false, "targets": [getColumnIndex(conf.columns, durationColumnName)]} + ]; + historySummary.append(apps); + apps.DataTable(conf); + sibling.after(historySummary); + $('#history-summary [data-toggle="tooltip"]').tooltip(); }); }); }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js 
b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index ff241470f32d..9960d5c34d1f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -207,8 +207,8 @@ sorttable = { hasInputs = (typeof node.getElementsByTagName == 'function') && node.getElementsByTagName('input').length; - - if (node.getAttribute("sorttable_customkey") != null) { + + if (node.nodeType == 1 && node.getAttribute("sorttable_customkey") != null) { return node.getAttribute("sorttable_customkey"); } else if (typeof node.textContent != 'undefined' && !hasInputs) { diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala index 5532931e2a79..3092074232d1 100644 --- a/core/src/main/scala/org/apache/spark/Accumulable.scala +++ b/core/src/main/scala/org/apache/spark/Accumulable.scala @@ -201,7 +201,8 @@ trait AccumulableParam[R, T] extends Serializable { @deprecated("use AccumulatorV2", "2.0.0") private[spark] class -GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] +GrowableAccumulableParam[R : ClassTag, T] + (implicit rg: R => Growable[T] with TraversableOnce[T] with Serializable) extends AccumulableParam[R, T] { def addAccumulator(growable: R, elem: T): R = { diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index fcc72ff49276..119b426a9af3 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.util.control.ControlThrowable +import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} @@ -217,7 +217,7 @@ private[spark] class ExecutorAllocationManager( * the scheduling task. */ def start(): Unit = { - listenerBus.addListener(listener) + listenerBus.addToManagementQueue(listener) val scheduleTask = new Runnable() { override def run(): Unit = { @@ -245,14 +245,15 @@ private[spark] class ExecutorAllocationManager( } /** - * Reset the allocation manager to the initial state. Currently this will only be called in - * yarn-client mode when AM re-registers after a failure. + * Reset the allocation manager when the cluster manager loses track of the driver's state. + * This is currently only done in YARN client mode, when the AM is restarted. + * + * This method forgets about any state about existing executors, and forces the scheduler to + * re-evaluate the number of needed executors the next time it's run. 
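In the `ExecutorAllocationManager` hunks below, the call to `requestTotalExecutors` is wrapped so that a transient failure (for example, during a YARN AM restart) is treated as an unacknowledged request rather than a crash. A minimal Java rendering of that pattern; the interface is a stand-in, not Spark's `ExecutorAllocationClient`:

```java
import java.util.logging.Level;
import java.util.logging.Logger;

final class ResilientRequestSketch {
  private static final Logger LOG = Logger.getLogger(ResilientRequestSketch.class.getName());

  /** Stand-in for the cluster-manager client; only this one call matters for the sketch. */
  interface ClusterManagerClient {
    boolean requestTotalExecutors(int target);
  }

  /** Treat transient failures as "request not acknowledged" instead of propagating them. */
  static boolean tryRequestTotalExecutors(ClusterManagerClient client, int target) {
    try {
      return client.requestTotalExecutors(target);
    } catch (RuntimeException e) {
      // Such errors are commonly caused by a recoverable cluster-manager restart, so keep
      // the log level low rather than alarming interactive shells.
      LOG.log(Level.INFO, "Error reaching cluster manager.", e);
      return false;
    }
  }
}
```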
*/ def reset(): Unit = synchronized { - initializing = true + addTime = 0L numExecutorsTarget = initialNumExecutors - numExecutorsToAdd = 1 - executorsPendingToRemove.clear() removeTimes.clear() } @@ -372,12 +373,27 @@ private[spark] class ExecutorAllocationManager( // If our target has not changed, do not send a message // to the cluster manager and reset our exponential growth if (delta == 0) { - numExecutorsToAdd = 1 - return 0 + // Check if there are any speculative tasks pending + if (listener.pendingTasks == 0 && listener.pendingSpeculativeTasks > 0) { + numExecutorsTarget = + math.max(math.min(maxNumExecutorsNeeded + 1, maxNumExecutors), minNumExecutors) + } else { + numExecutorsToAdd = 1 + return 0 + } } - val addRequestAcknowledged = testing || - client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) + val addRequestAcknowledged = try { + testing || + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) + } catch { + case NonFatal(e) => + // Use INFO level so the error doesn't show up by default in shells. Errors here are most + // commonly caused by YARN AM restarts, which are a recoverable issue and generate a lot of + // noisy output. + logInfo("Error reaching cluster manager.", e) + false + } if (addRequestAcknowledged) { val executorsString = "executor" + { if (delta > 1) "s" else "" } logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + @@ -410,7 +426,10 @@ private[spark] class ExecutorAllocationManager( executors.foreach { executorIdToBeRemoved => if (newExecutorTotal - 1 < minNumExecutors) { logDebug(s"Not removing idle executor $executorIdToBeRemoved because there are only " + - s"$newExecutorTotal executor(s) left (limit $minNumExecutors)") + s"$newExecutorTotal executor(s) left (minimum number of executor limit $minNumExecutors)") + } else if (newExecutorTotal - 1 < numExecutorsTarget) { + logDebug(s"Not removing idle executor $executorIdToBeRemoved because there are only " + + s"$newExecutorTotal executor(s) left (number of executor target $numExecutorsTarget)") } else if (canBeKilled(executorIdToBeRemoved)) { executorIdsToBeRemoved += executorIdToBeRemoved newExecutorTotal -= 1 @@ -427,6 +446,9 @@ private[spark] class ExecutorAllocationManager( } else { client.killExecutors(executorIdsToBeRemoved) } + // [SPARK-21834] The killExecutors API reduces the target number of executors, + // so we need to update the target with the desired value. + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) // reset the newExecutorTotal to the existing number of executors newExecutorTotal = numExistingExecutors if (testing || executorsRemoved.nonEmpty) { @@ -575,17 +597,22 @@ private[spark] class ExecutorAllocationManager( * A listener that notifies the given allocation manager of when to add and remove executors. * * This class is intentionally conservative in its assumptions about the relative ordering - * and consistency of events returned by the listener. For simplicity, it does not account - * for speculated tasks. + * and consistency of events returned by the listener. */ private class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] - // Number of tasks currently running on the cluster.
Should be 0 when no stages are active. + // Number of tasks currently running on the cluster including speculative tasks. + // Should be 0 when no stages are active. private var numRunningTasks: Int = _ + // Number of speculative tasks to be scheduled in each stage + private val stageIdToNumSpeculativeTasks = new mutable.HashMap[Int, Int] + // The speculative tasks started in each stage + private val stageIdToSpeculativeTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] + // stageId to tuple (the number of task with locality preferences, a map where each pair is a // node and the number of tasks that would like to be scheduled on that node) map, // maintain the executor placement hints for each stage Id used by resource framework to better @@ -624,7 +651,9 @@ private[spark] class ExecutorAllocationManager( val stageId = stageCompleted.stageInfo.stageId allocationManager.synchronized { stageIdToNumTasks -= stageId + stageIdToNumSpeculativeTasks -= stageId stageIdToTaskIndices -= stageId + stageIdToSpeculativeTaskIndices -= stageId stageIdToExecutorPlacementHints -= stageId // Update the executor placement hints @@ -632,7 +661,7 @@ private[spark] class ExecutorAllocationManager( // If this is the last stage with pending tasks, mark the scheduler queue as empty // This is needed in case the stage is aborted for any reason - if (stageIdToNumTasks.isEmpty) { + if (stageIdToNumTasks.isEmpty && stageIdToNumSpeculativeTasks.isEmpty) { allocationManager.onSchedulerQueueEmpty() if (numRunningTasks != 0) { logWarning("No stages are running, but numRunningTasks != 0") @@ -658,7 +687,12 @@ private[spark] class ExecutorAllocationManager( } // If this is the last pending task, mark the scheduler queue as empty - stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex + if (taskStart.taskInfo.speculative) { + stageIdToSpeculativeTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += + taskIndex + } else { + stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex + } if (totalPendingTasks() == 0) { allocationManager.onSchedulerQueueEmpty() } @@ -692,7 +726,11 @@ private[spark] class ExecutorAllocationManager( if (totalPendingTasks() == 0) { allocationManager.onSchedulerBacklogged() } - stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) } + if (taskEnd.taskInfo.speculative) { + stageIdToSpeculativeTaskIndices.get(stageId).foreach {_.remove(taskIndex)} + } else { + stageIdToTaskIndices.get(stageId).foreach {_.remove(taskIndex)} + } } } } @@ -713,18 +751,39 @@ private[spark] class ExecutorAllocationManager( allocationManager.onExecutorRemoved(executorRemoved.executorId) } + override def onSpeculativeTaskSubmitted(speculativeTask: SparkListenerSpeculativeTaskSubmitted) + : Unit = { + val stageId = speculativeTask.stageId + + allocationManager.synchronized { + stageIdToNumSpeculativeTasks(stageId) = + stageIdToNumSpeculativeTasks.getOrElse(stageId, 0) + 1 + allocationManager.onSchedulerBacklogged() + } + } + /** * An estimate of the total number of pending tasks remaining for currently running stages. Does * not account for tasks which may have failed and been resubmitted. * * Note: This is not thread-safe without the caller owning the `allocationManager` lock. 
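The listener changes in this region track speculative tasks separately from regular ones and report the total backlog as their sum. A small sketch of that bookkeeping, with plain Java maps standing in for the Scala `HashMap`s above:

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

final class PendingTaskBookkeepingSketch {
  private final Map<Integer, Integer> stageToNumTasks = new HashMap<>();
  private final Map<Integer, Set<Integer>> stageToStartedIndices = new HashMap<>();
  private final Map<Integer, Integer> stageToNumSpeculative = new HashMap<>();
  private final Map<Integer, Set<Integer>> stageToStartedSpeculative = new HashMap<>();

  synchronized void onStageSubmitted(int stageId, int numTasks) {
    stageToNumTasks.put(stageId, numTasks);
  }

  synchronized void onSpeculativeTaskSubmitted(int stageId) {
    stageToNumSpeculative.merge(stageId, 1, Integer::sum);
  }

  synchronized void onTaskStart(int stageId, int taskIndex, boolean speculative) {
    Map<Integer, Set<Integer>> target =
        speculative ? stageToStartedSpeculative : stageToStartedIndices;
    target.computeIfAbsent(stageId, k -> new HashSet<>()).add(taskIndex);
  }

  synchronized int pendingTasks() {
    return pending(stageToNumTasks, stageToStartedIndices);
  }

  synchronized int pendingSpeculativeTasks() {
    return pending(stageToNumSpeculative, stageToStartedSpeculative);
  }

  /** Total backlog is regular plus speculative pending tasks, as in the diff above. */
  synchronized int totalPendingTasks() {
    return pendingTasks() + pendingSpeculativeTasks();
  }

  private static int pending(Map<Integer, Integer> totals, Map<Integer, Set<Integer>> started) {
    int sum = 0;
    for (Map.Entry<Integer, Integer> e : totals.entrySet()) {
      Set<Integer> s = started.get(e.getKey());
      sum += e.getValue() - (s == null ? 0 : s.size());
    }
    return sum;
  }
}
```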
*/ - def totalPendingTasks(): Int = { + def pendingTasks(): Int = { stageIdToNumTasks.map { case (stageId, numTasks) => numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0) }.sum } + def pendingSpeculativeTasks(): Int = { + stageIdToNumSpeculativeTasks.map { case (stageId, numTasks) => + numTasks - stageIdToSpeculativeTaskIndices.get(stageId).map(_.size).getOrElse(0) + }.sum + } + + def totalPendingTasks(): Int = { + pendingTasks + pendingSpeculativeTasks + } + /** * The number of tasks currently running across all stages. */ diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index a50600f1488c..1034fdcae8e8 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -89,6 +89,14 @@ trait FutureAction[T] extends Future[T] { */ override def value: Option[Try[T]] + // These two methods must be implemented in Scala 2.12, but won't be used by Spark + + def transform[S](f: (Try[T]) => Try[S])(implicit executor: ExecutionContext): Future[S] = + throw new UnsupportedOperationException() + + def transformWith[S](f: (Try[T]) => Future[S])(implicit executor: ExecutionContext): Future[S] = + throw new UnsupportedOperationException() + /** * Blocks and returns the result of this job. */ @@ -261,7 +269,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S private def getImpl(timeout: Duration): T = { // This will throw TimeoutException on timeout: - Await.ready(futureAction, timeout) + ThreadUtils.awaitReady(futureAction, timeout) futureAction.value.get match { case scala.util.Success(value) => converter(value) case scala.util.Failure(exception) => diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 5242ab6f5523..ff960b396dbf 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -63,7 +63,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) this(sc, new SystemClock) } - sc.addSparkListener(this) + sc.listenerBus.addToManagementQueue(this) override val rpcEnv: RpcEnv = sc.env.rpcEnv diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 82d3098e2e05..18b10d23da94 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -50,6 +50,7 @@ private[spark] object InternalAccumulator { val REMOTE_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "remoteBlocksFetched" val LOCAL_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "localBlocksFetched" val REMOTE_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "remoteBytesRead" + val REMOTE_BYTES_READ_TO_DISK = SHUFFLE_READ_METRICS_PREFIX + "remoteBytesReadToDisk" val LOCAL_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "localBytesRead" val FETCH_WAIT_TIME = SHUFFLE_READ_METRICS_PREFIX + "fetchWaitTime" val RECORDS_READ = SHUFFLE_READ_METRICS_PREFIX + "recordsRead" diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4ef665622245..7f760a59bda2 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -34,6 +34,178 @@ import 
org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ +/** + * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single + * ShuffleMapStage. + * + * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of + * serialized map statuses in order to speed up tasks' requests for map output statuses. + * + * All public methods of this class are thread-safe. + */ +private class ShuffleStatus(numPartitions: Int) { + + // All accesses to the following state must be guarded with `this.synchronized`. + + /** + * MapStatus for each partition. The index of the array is the map partition id. + * Each value in the array is the MapStatus for a partition, or null if the partition + * is not available. Even though in theory a task may run multiple times (due to speculation, + * stage retries, etc.), in practice the likelihood of a map output being available at multiple + * locations is so small that we choose to ignore that case and store only a single location + * for each output. + */ + // Exposed for testing + val mapStatuses = new Array[MapStatus](numPartitions) + + /** + * The cached result of serializing the map statuses array. This cache is lazily populated when + * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. + */ + private[this] var cachedSerializedMapStatus: Array[Byte] = _ + + /** + * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]] + * serializes the map statuses array it may detect that the result is too large to send in a + * single RPC, in which case it places the serialized array into a broadcast variable and then + * sends a serialized broadcast variable instead. This variable holds a reference to that + * broadcast variable in order to keep it from being garbage collected and to allow for it to be + * explicitly destroyed later on when the ShuffleMapStage is garbage-collected. + */ + private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + + /** + * Counter tracking the number of partitions that have output. This is a performance optimization + * to avoid having to count the number of non-null entries in the `mapStatuses` array and should + * be equivalent to`mapStatuses.count(_ ne null)`. + */ + private[this] var _numAvailableOutputs: Int = 0 + + /** + * Register a map output. If there is already a registered location for the map output then it + * will be replaced by the new location. + */ + def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized { + if (mapStatuses(mapId) == null) { + _numAvailableOutputs += 1 + invalidateSerializedMapOutputStatusCache() + } + mapStatuses(mapId) = status + } + + /** + * Remove the map output which was served by the specified block manager. + * This is a no-op if there is no registered map output or if the registered output is from a + * different block manager. + */ + def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Removes all shuffle outputs associated with this host. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists). 
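The new `ShuffleStatus` above pairs the per-map status array with a running count of available outputs and invalidates a serialized cache whenever the array changes. A condensed Java sketch of that bookkeeping; the nested `MapStatus` type is a placeholder for Spark's class:

```java
final class ShuffleStatusSketch {
  /** Placeholder: only presence or absence of a status matters for this sketch. */
  static final class MapStatus {}

  private final MapStatus[] mapStatuses;
  private int numAvailableOutputs = 0;
  private byte[] cachedSerializedStatuses = null;

  ShuffleStatusSketch(int numPartitions) {
    this.mapStatuses = new MapStatus[numPartitions];
  }

  synchronized void addMapOutput(int mapId, MapStatus status) {
    if (mapStatuses[mapId] == null) {
      numAvailableOutputs++;
      cachedSerializedStatuses = null; // the cached serialized form is now stale
    }
    mapStatuses[mapId] = status;
  }

  synchronized void removeMapOutput(int mapId) {
    if (mapStatuses[mapId] != null) {
      numAvailableOutputs--;
      mapStatuses[mapId] = null;
      cachedSerializedStatuses = null;
    }
  }

  /** O(1) answer instead of scanning the array for non-null entries on every call. */
  synchronized int numAvailableOutputs() {
    return numAvailableOutputs;
  }
}
```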
+ */ + def removeOutputsOnHost(host: String): Unit = { + removeOutputsByFilter(x => x.host == host) + } + + /** + * Removes all map outputs associated with the specified executor. Note that this will also + * remove outputs which are served by an external shuffle server (if one exists), as they are + * still registered with that execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = synchronized { + removeOutputsByFilter(x => x.executorId == execId) + } + + /** + * Removes all shuffle outputs which satisfies the filter. Note that this will also + * remove outputs which are served by an external shuffle server (if one exists). + */ + def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized { + for (mapId <- 0 until mapStatuses.length) { + if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } + } + + /** + * Number of partitions that have shuffle outputs. + */ + def numAvailableOutputs: Int = synchronized { + _numAvailableOutputs + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + */ + def findMissingPartitions(): Seq[Int] = synchronized { + val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null) + assert(missing.size == numPartitions - _numAvailableOutputs, + s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") + missing + } + + /** + * Serializes the mapStatuses array into an efficient compressed format. See the comments on + * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format. + * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the map statuses then serialization will only be performed in a single thread and all + * other threads will block until the cache is populated. + */ + def serializedMapStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int): Array[Byte] = synchronized { + if (cachedSerializedMapStatus eq null) { + val serResult = MapOutputTracker.serializeMapStatuses( + mapStatuses, broadcastManager, isLocal, minBroadcastSize) + cachedSerializedMapStatus = serResult._1 + cachedSerializedBroadcast = serResult._2 + } + cachedSerializedMapStatus + } + + // Used in testing. + def hasCachedSerializedBroadcast: Boolean = synchronized { + cachedSerializedBroadcast != null + } + + /** + * Helper function which provides thread-safe access to the mapStatuses array. + * The function should NOT mutate the array. + */ + def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized { + f(mapStatuses) + } + + /** + * Clears the cached serialized map output statuses. + */ + def invalidateSerializedMapOutputStatusCache(): Unit = synchronized { + if (cachedSerializedBroadcast != null) { + // Prevent errors during broadcast cleanup from crashing the DAGScheduler (see SPARK-21444) + Utils.tryLogNonFatalError { + // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup + // RPCs to dead executors. 
+ cachedSerializedBroadcast.destroy(blocking = false) + } + cachedSerializedBroadcast = null + } + cachedSerializedMapStatus = null + } +} + private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage @@ -62,37 +234,26 @@ private[spark] class MapOutputTrackerMasterEndpoint( } /** - * Class that keeps track of the location of the map output of - * a stage. This is abstract because different versions of MapOutputTracker - * (driver and executor) use different HashMap to store its metadata. - */ + * Class that keeps track of the location of the map output of a stage. This is abstract because the + * driver and executor have different versions of the MapOutputTracker. In principle the driver- + * and executor-side classes don't need to share a common base class; the current shared base class + * is maintained primarily for backwards-compatibility in order to avoid having to update existing + * test code. +*/ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ var trackerEndpoint: RpcEndpointRef = _ /** - * This HashMap has different behavior for the driver and the executors. - * - * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks. - * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the - * driver's corresponding HashMap. - * - * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a - * thread-safe map. - */ - protected val mapStatuses: Map[Int, Array[MapStatus]] - - /** - * Incremented every time a fetch fails so that client nodes know to clear - * their cache of map output locations if this happens. + * The driver-side counter is incremented every time that a map output is lost. This value is sent + * to executors as part of tasks, where executors compare the new epoch number to the highest + * epoch number that they received in the past. If the new epoch number is higher then executors + * will clear their local caches of map output statuses and will re-fetch (possibly updated) + * statuses from the driver. */ protected var epoch: Long = 0 protected val epochLock = new AnyRef - /** Remembers which map output locations are currently being fetched on an executor. */ - private val fetching = new HashSet[Int] - /** * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. @@ -116,14 +277,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** - * Called from executors to get the server URIs and output sizes for each shuffle block that - * needs to be read from a given reduce task. - * - * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, - * and the second item is a sequence of (shuffle block id, shuffle block size) tuples - * describing the shuffle blocks that are stored at that block manager. - */ + // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) @@ -139,135 +293,31 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. 
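The rewritten epoch comment above describes a simple invalidation protocol: the driver bumps a counter whenever a map output is lost, tasks carry the latest value, and an executor clears its local cache when it sees a newer epoch. A toy sketch of the executor side; the cache type and field names are illustrative:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class EpochCacheSketch {
  private final Object epochLock = new Object();
  private long epoch = 0;
  private final Map<Integer, long[]> cachedStatusesByShuffle = new ConcurrentHashMap<>();

  /** Called with the epoch number carried by each incoming task. */
  void updateEpoch(long newEpoch) {
    synchronized (epochLock) {
      if (newEpoch > epoch) {
        // A newer epoch means some map output changed on the driver; drop everything cached
        // locally and re-fetch statuses lazily on the next lookup.
        epoch = newEpoch;
        cachedStatusesByShuffle.clear();
      }
    }
  }
}
```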
*/ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { - logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - val statuses = getStatuses(shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) - } - } + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] /** - * Return statistics about all of the outputs for a given shuffle. + * Deletes map output status information for the specified shuffle stage. */ - def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { - val statuses = getStatuses(dep.shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - val totalSizes = new Array[Long](dep.partitioner.numPartitions) - for (s <- statuses) { - for (i <- 0 until totalSizes.length) { - totalSizes(i) += s.getSizeForBlock(i) - } - } - new MapOutputStatistics(dep.shuffleId, totalSizes) - } - } - - /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize - * on this array when reading it, because on the driver, we may be changing it in place. - * - * (It would be nice to remove this restriction in the future.) - */ - private def getStatuses(shuffleId: Int): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - val startTime = System.currentTimeMillis - var fetchedStatuses: Array[MapStatus] = null - fetching.synchronized { - // Someone else is fetching it; wait for them to be done - while (fetching.contains(shuffleId)) { - try { - fetching.wait() - } catch { - case e: InterruptedException => - } - } - - // Either while we waited the fetch happened successfully, or - // someone fetched it in between the get and the fetching.synchronized. - fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - // We have to do the fetch, get others to wait for us. - fetching += shuffleId - } - } + def unregisterShuffle(shuffleId: Int): Unit - if (fetchedStatuses == null) { - // We won the race to fetch the statuses; do so - logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) - // This try-finally prevents hangs due to timeouts: - try { - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() - } - } - } - logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + - s"${System.currentTimeMillis - startTime} ms") - - if (fetchedStatuses != null) { - return fetchedStatuses - } else { - logError("Missing all output locations for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) - } - } else { - return statuses - } - } - - /** Called to get current epoch number. */ - def getEpoch: Long = { - epochLock.synchronized { - return epoch - } - } - - /** - * Called from executors to update the epoch number, potentially clearing old outputs - * because of a fetch failure. 
Each executor task calls this with the latest epoch - * number on the driver at the time it was created. - */ - def updateEpoch(newEpoch: Long) { - epochLock.synchronized { - if (newEpoch > epoch) { - logInfo("Updating epoch to " + newEpoch + " and clearing cache") - epoch = newEpoch - mapStatuses.clear() - } - } - } - - /** Unregister shuffle data. */ - def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - } - - /** Stop the tracker. */ - def stop() { } + def stop() {} } /** - * MapOutputTracker for the driver. + * Driver-side class that keeps track of the location of the map output of a stage. + * + * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics + * for performing locality-aware reduce task scheduling. + * + * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine + * which tasks need to be run. */ -private[spark] class MapOutputTrackerMaster(conf: SparkConf, - broadcastManager: BroadcastManager, isLocal: Boolean) +private[spark] class MapOutputTrackerMaster( + conf: SparkConf, + broadcastManager: BroadcastManager, + isLocal: Boolean) extends MapOutputTracker(conf) { - /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ - private var cacheEpoch = epoch - // The size at which we use Broadcast to send the map output statuses to the executors private val minSizeForBroadcast = conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt @@ -287,22 +337,13 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, // can be read locally, but may lead to more delay in scheduling if those locations are busy. private val REDUCER_PREF_LOCS_FRACTION = 0.2 - // HashMaps for storing mapStatuses and cached serialized statuses in the driver. + // HashMap for storing shuffleStatuses in the driver. // Statuses are dropped only by explicit de-registering. - protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala - private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala + // Exposed for testing + val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - // Kept in sync with cachedSerializedStatuses explicitly - // This is required so that the Broadcast variable remains in scope until we remove - // the shuffleId explicitly or implicitly. - private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]() - - // This is to prevent multiple serializations of the same shuffle - which happens when - // there is a request storm when shuffle start. 
- private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]() - // requests for map output statuses private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] @@ -348,8 +389,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, val hostPort = context.senderAddress.hostPort logDebug("Handling request to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) - context.reply(mapOutputStatuses) + val shuffleStatus = shuffleStatuses.get(shuffleId).head + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -363,59 +405,86 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, /** A poison endpoint that indicates MessageLoop should exit its message loop. */ private val PoisonPill = new GetMapOutputMessage(-99, null) - // Exposed for testing - private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size + // Used only in unit tests. + private[spark] def getNumCachedSerializedBroadcast: Int = { + shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) + } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - // add in advance - shuffleIdLocks.putIfAbsent(shuffleId, new Object()) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - val array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } - } - - /** Register multiple map output information for the given shuffle */ - def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { - mapStatuses.put(shuffleId, statuses.clone()) - if (changeEpoch) { - incrementEpoch() - } + shuffleStatuses(shuffleId).addMapOutput(mapId, status) } /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - val arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - val array = arrayOpt.get - array.synchronized { - if (array(mapId) != null && array(mapId).location == bmAddress) { - array(mapId) = null - } - } - incrementEpoch() - } else { - throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeMapOutput(mapId, bmAddress) + incrementEpoch() + case None => + throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") } } /** Unregister shuffle data */ - override def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - cachedSerializedStatuses.remove(shuffleId) - cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v)) - shuffleIdLocks.remove(shuffleId) + def unregisterShuffle(shuffleId: Int) { + shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => + shuffleStatus.invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Removes all shuffle outputs associated with this host. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists). 
+ */ + def removeOutputsOnHost(host: String): Unit = { + shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) } + incrementEpoch() + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = { + shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) } + incrementEpoch() } /** Check if the given shuffle is being tracked */ - def containsShuffle(shuffleId: Int): Boolean = { - cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) + def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId) + + def getNumAvailableOutputs(shuffleId: Int): Int = { + shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0) + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed), or None + * if the MapOutputTrackerMaster doesn't know about this shuffle. + */ + def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = { + shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) + } + + /** + * Return statistics about all of the outputs for a given shuffle. + */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } } /** @@ -459,9 +528,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, fractionThreshold: Double) : Option[Array[BlockManagerId]] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses != null) { - statuses.synchronized { + val shuffleStatus = shuffleStatuses.get(shuffleId).orNull + if (shuffleStatus != null) { + shuffleStatus.withMapStatuses { statuses => if (statuses.nonEmpty) { // HashMap to add up sizes of all blocks at the same location val locs = new HashMap[BlockManagerId, Long] @@ -502,77 +571,24 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, } } - private def removeBroadcast(bcast: Broadcast[_]): Unit = { - if (null != bcast) { - broadcastManager.unbroadcast(bcast.id, - removeFromDriver = true, blocking = false) + /** Called to get current epoch number. */ + def getEpoch: Long = { + epochLock.synchronized { + return epoch } } - private def clearCachedBroadcast(): Unit = { - for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2) - cachedSerializedBroadcast.clear() - } - - def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { - var statuses: Array[MapStatus] = null - var retBytes: Array[Byte] = null - var epochGotten: Long = -1 - - // Check to see if we have a cached version, returns true if it does - // and has side effect of setting retBytes. 
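For context on the driver-side refactor above: the per-shuffle `Array[MapStatus]`, the cached serialized bytes, and the per-shuffle locks are folded into a single `ShuffleStatus` object per shuffle. The sketch below is a hypothetical illustration of how the new `MapOutputTrackerMaster` surface is exercised from scheduler code; it uses only methods introduced or kept by this patch, assumes a tracker instance is already at hand (for example via `SparkEnv` on the driver), and the executor id, host, port and block sizes are made-up values.

```scala
import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId

// Hypothetical driver-side walkthrough (these are internal, private[spark] APIs).
val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]

val shuffleId = 0
tracker.registerShuffle(shuffleId, numMaps = 2)

// As each map task finishes, its MapStatus is recorded against the shuffle.
val bm = BlockManagerId("exec-1", "host-a", 7337)
tracker.registerMapOutput(shuffleId, mapId = 0, MapStatus(bm, Array(100L, 200L)))

// ShuffleMapStage can now ask what still needs to be computed.
tracker.findMissingPartitions(shuffleId)   // Some(Seq(1)): map output 1 is still missing
tracker.getNumAvailableOutputs(shuffleId)  // 1

// A fetch failure on host-a drops its outputs and bumps the epoch, so executors
// clear their cached statuses on their next updateEpoch call.
tracker.removeOutputsOnHost("host-a")
```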
If not returns false - // with side effect of setting statuses - def checkCachedStatuses(): Boolean = { - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - clearCachedBroadcast() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - retBytes = bytes - true - case None => - logDebug("cached status not found for : " + shuffleId) - statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus]) - epochGotten = epoch - false - } - } - } - - if (checkCachedStatuses()) return retBytes - var shuffleIdLock = shuffleIdLocks.get(shuffleId) - if (null == shuffleIdLock) { - val newLock = new Object() - // in general, this condition should be false - but good to be paranoid - val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) - shuffleIdLock = if (null != prevLock) prevLock else newLock - } - // synchronize so we only serialize/broadcast it once since multiple threads call - // in parallel - shuffleIdLock.synchronized { - // double check to make sure someone else didn't serialize and cache the same - // mapstatus while we were waiting on the synchronize - if (checkCachedStatuses()) return retBytes - - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "statuses"; let's serialize and return that - val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager, - isLocal, minSizeForBroadcast) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes - if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast - } else { - logInfo("Epoch changed, not caching!") - removeBroadcast(bcast) + // This method is only called in local-mode. + def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + shuffleStatuses.get(shuffleId) match { + case Some (shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } - } - bytes + case None => + Seq.empty } } @@ -580,21 +596,121 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, mapOutputRequests.offer(PoisonPill) threadpool.shutdown() sendTracker(StopMapOutputTracker) - mapStatuses.clear() trackerEndpoint = null - cachedSerializedStatuses.clear() - clearCachedBroadcast() - shuffleIdLocks.clear() + shuffleStatuses.clear() } } /** - * MapOutputTracker for the executors, which fetches map output information from the driver's - * MapOutputTrackerMaster. + * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster. + * Note that this is not used in local-mode; instead, local-mode Executors access the + * MapOutputTrackerMaster directly (which is possible because the master and worker share a comon + * superclass). 
*/ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { - protected val mapStatuses: Map[Int, Array[MapStatus]] = + + val mapStatuses: Map[Int, Array[MapStatus]] = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + + /** Remembers which map output locations are currently being fetched on an executor. */ + private val fetching = new HashSet[Int] + + override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + val statuses = getStatuses(shuffleId) + try { + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } catch { + case e: MetadataFetchFailedException => + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + mapStatuses.clear() + throw e + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis + var fetchedStatuses: Array[MapStatus] = null + fetching.synchronized { + // Someone else is fetching it; wait for them to be done + while (fetching.contains(shuffleId)) { + try { + fetching.wait() + } catch { + case e: InterruptedException => + } + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. + fetching += shuffleId + } + } + + if (fetchedStatuses == null) { + // We won the race to fetch the statuses; do so + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + // This try-finally prevents hangs due to timeouts: + try { + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${System.currentTimeMillis - startTime} ms") + + if (fetchedStatuses != null) { + fetchedStatuses + } else { + logError("Missing all output locations for shuffle " + shuffleId) + throw new MetadataFetchFailedException( + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) + } + } else { + statuses + } + } + + + /** Unregister shuffle data. */ + def unregisterShuffle(shuffleId: Int): Unit = { + mapStatuses.remove(shuffleId) + } + + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each executor task calls this with the latest epoch + * number on the driver at the time it was created. 
+ */ + def updateEpoch(newEpoch: Long): Unit = { + epochLock.synchronized { + if (newEpoch > epoch) { + logInfo("Updating epoch to " + newEpoch + " and clearing cache") + epoch = newEpoch + mapStatuses.clear() + } + } + } } private[spark] object MapOutputTracker extends Logging { @@ -683,7 +799,7 @@ private[spark] object MapOutputTracker extends Logging { * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - private def convertMapStatuses( + def convertMapStatuses( shuffleId: Int, startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index f83f5278e8b8..debbd8d7c26c 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -108,11 +108,21 @@ class HashPartitioner(partitions: Int) extends Partitioner { class RangePartitioner[K : Ordering : ClassTag, V]( partitions: Int, rdd: RDD[_ <: Product2[K, V]], - private var ascending: Boolean = true) + private var ascending: Boolean = true, + val samplePointsPerPartitionHint: Int = 20) extends Partitioner { + // A constructor declared in order to maintain backward compatibility for Java, when we add the + // 4th constructor parameter samplePointsPerPartitionHint. See SPARK-22160. + // This is added to make sure from a bytecode point of view, there is still a 3-arg ctor. + def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = { + this(partitions, rdd, ascending, samplePointsPerPartitionHint = 20) + } + // We allow partitions = 0, which happens when sorting an empty RDD under the default settings. require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.") + require(samplePointsPerPartitionHint > 0, + s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint") private var ordering = implicitly[Ordering[K]] @@ -122,7 +132,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( Array.empty } else { // This is the sample size we need to have roughly balanced output partitions, capped at 1M. - val sampleSize = math.min(20.0 * partitions, 1e6) + // Cast to double to avoid overflowing ints or longs + val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6) // Assume the input partitions are roughly balanced and over-sample a little bit. val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition) @@ -153,7 +164,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( val weight = (1.0 / fraction).toFloat candidates ++= reSampled.map(x => (x, weight)) } - RangePartitioner.determineBounds(candidates, partitions) + RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size)) } } } diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 29163e7f3054..477b01968c6e 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -94,21 +94,23 @@ private[spark] case class SSLOptions( * are supported by the current Java security provider for this protocol. 
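Stepping back to the `RangePartitioner` change above: the hard-coded factor of 20 in the sample-size formula becomes the constructor parameter `samplePointsPerPartitionHint`, and a 3-arg auxiliary constructor is kept so existing Java callers still link (SPARK-22160). A small worked example of the arithmetic, with illustrative numbers only:

```scala
// Reproducing the sampling formula from the patch with made-up values.
val partitions = 1000                    // requested output partitions
val samplePointsPerPartitionHint = 60    // new parameter; the default stays 20

// Capped at 1e6 so the sample collected on the driver stays bounded.
val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6) // 60000.0

// Each input partition is over-sampled by 3x; assume 200 input partitions here.
val inputPartitions = 200
val sampleSizePerPartition = math.ceil(3.0 * sampleSize / inputPartitions).toInt   // 900
```

Note also that `determineBounds` is now called with `math.min(partitions, candidates.size)`, so the number of range bounds is capped by the number of sampled candidates.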
*/ private val supportedAlgorithms: Set[String] = if (enabledAlgorithms.isEmpty) { - Set() + Set.empty } else { var context: SSLContext = null - try { - context = SSLContext.getInstance(protocol.orNull) - /* The set of supported algorithms does not depend upon the keys, trust, or + if (protocol.isEmpty) { + logDebug("No SSL protocol specified") + context = SSLContext.getDefault + } else { + try { + context = SSLContext.getInstance(protocol.get) + /* The set of supported algorithms does not depend upon the keys, trust, or rng, although they will influence which algorithms are eventually used. */ - context.init(null, null, null) - } catch { - case npe: NullPointerException => - logDebug("No SSL protocol specified") - context = SSLContext.getDefault - case nsa: NoSuchAlgorithmException => - logDebug(s"No support for requested SSL protocol ${protocol.get}") - context = SSLContext.getDefault + context.init(null, null, null) + } catch { + case nsa: NoSuchAlgorithmException => + logDebug(s"No support for requested SSL protocol ${protocol.get}") + context = SSLContext.getDefault + } } val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet @@ -167,39 +169,39 @@ private[spark] object SSLOptions extends Logging { def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) - val port = conf.getOption(s"$ns.port").map(_.toInt) + val port = conf.getWithSubstitution(s"$ns.port").map(_.toInt) port.foreach { p => require(p >= 0, "Port number must be a non-negative value.") } - val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_)) + val keyStore = conf.getWithSubstitution(s"$ns.keyStore").map(new File(_)) .orElse(defaults.flatMap(_.keyStore)) - val keyStorePassword = conf.getOption(s"$ns.keyStorePassword") + val keyStorePassword = conf.getWithSubstitution(s"$ns.keyStorePassword") .orElse(defaults.flatMap(_.keyStorePassword)) - val keyPassword = conf.getOption(s"$ns.keyPassword") + val keyPassword = conf.getWithSubstitution(s"$ns.keyPassword") .orElse(defaults.flatMap(_.keyPassword)) - val keyStoreType = conf.getOption(s"$ns.keyStoreType") + val keyStoreType = conf.getWithSubstitution(s"$ns.keyStoreType") .orElse(defaults.flatMap(_.keyStoreType)) val needClientAuth = conf.getBoolean(s"$ns.needClientAuth", defaultValue = defaults.exists(_.needClientAuth)) - val trustStore = conf.getOption(s"$ns.trustStore").map(new File(_)) + val trustStore = conf.getWithSubstitution(s"$ns.trustStore").map(new File(_)) .orElse(defaults.flatMap(_.trustStore)) - val trustStorePassword = conf.getOption(s"$ns.trustStorePassword") + val trustStorePassword = conf.getWithSubstitution(s"$ns.trustStorePassword") .orElse(defaults.flatMap(_.trustStorePassword)) - val trustStoreType = conf.getOption(s"$ns.trustStoreType") + val trustStoreType = conf.getWithSubstitution(s"$ns.trustStoreType") .orElse(defaults.flatMap(_.trustStoreType)) - val protocol = conf.getOption(s"$ns.protocol") + val protocol = conf.getWithSubstitution(s"$ns.protocol") .orElse(defaults.flatMap(_.protocol)) - val enabledAlgorithms = conf.getOption(s"$ns.enabledAlgorithms") + val enabledAlgorithms = conf.getWithSubstitution(s"$ns.enabledAlgorithms") .map(_.split(",").map(_.trim).filter(_.nonEmpty).toSet) .orElse(defaults.map(_.enabledAlgorithms)) .getOrElse(Set.empty) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 
2a2ce0504dbb..e61f943af49f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -373,6 +373,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) } + /** Get an optional value, applying variable substitution. */ + private[spark] def getWithSubstitution(key: String): Option[String] = { + getOption(key).map(reader.substitute(_)) + } + /** Get all parameters as a list of pairs */ def getAll: Array[(String, String)] = { settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray @@ -543,6 +548,17 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } + if (contains("spark.cores.max") && contains("spark.executor.cores")) { + val totalCores = getInt("spark.cores.max", 1) + val executorCores = getInt("spark.executor.cores", 1) + val leftCores = totalCores % executorCores + if (leftCores != 0) { + logWarning(s"Total executor cores: ${totalCores} is not " + + s"divisible by cores per executor: ${executorCores}, " + + s"the left cores: ${leftCores} will not be allocated") + } + } + val encryptionEnabled = get(NETWORK_ENCRYPTION_ENABLED) || get(SASL_ENCRYPTION_ENABLED) require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED), s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.") @@ -579,7 +595,11 @@ private[spark] object SparkConf extends Logging { "are no longer accepted. To specify the equivalent now, one may use '64k'."), DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", - "Please use the new blacklisting options, spark.blacklist.*") + "Please use the new blacklisting options, spark.blacklist.*"), + DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.shuffle.service.index.cache.entries", "2.3.0", + "Not used any more. Please use spark.shuffle.service.index.cache.size") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) @@ -590,6 +610,8 @@ private[spark] object SparkConf extends Logging { * * The alternates are used in the order defined in this map. If deprecated configs are * present in the user's configuration, a warning is logged. + * + * TODO: consolidate it with `ConfigBuilder.withAlternative`. */ private val configsWithAlternatives = Map[String, Seq[AlternateConfig]]( "spark.executor.userClassPathFirst" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0ec1bdd39b2f..cec61d85ccf3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -145,9 +145,8 @@ class SparkContext(config: SparkConf) extends Logging { this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment)) } - // NOTE: The below constructors could be consolidated using default arguments. Due to - // Scala bug SI-8479, however, this causes the compile step to fail when generating docs. - // Until we have a good workaround for that bug the constructors remain broken out. + // The following constructors are required when Java code accesses SparkContext directly. 
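Stepping back briefly to the `SparkConf` changes above before continuing with `SparkContext`: SSL-related options are now read through the new `getWithSubstitution`, and a warning is logged when `spark.cores.max` is not a multiple of `spark.executor.cores`. A minimal, hypothetical sketch of both behaviours; the keys and values are invented, and the `${env:...}` reference assumes the variable-substitution syntax handled by Spark's config reader:

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf(loadDefaults = false)
  // Read back via getWithSubstitution (e.g. by SSLOptions.parse), so the secret can be
  // pulled from the environment instead of being written into the config literally.
  .set("spark.ssl.keyStorePassword", "${env:KEYSTORE_PASSWORD}")
  // 10 % 4 == 2, so the new check logs that 2 of the requested cores can never be allocated.
  .set("spark.cores.max", "10")
  .set("spark.executor.cores", "4")
```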
+ // Please see SI-4278 /** * Alternative constructor that allows setting common Spark properties directly @@ -183,8 +182,6 @@ class SparkContext(config: SparkConf) extends Logging { // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") - warnDeprecatedVersions() - /* ------------------------------------------------------------------------------------- * | Private variables. These variables keep the internal state of the context, and are | | not accessible by the outside world. They're mutable since we want to initialize all | @@ -195,6 +192,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _conf: SparkConf = _ private var _eventLogDir: Option[URI] = None private var _eventLogCodec: Option[String] = None + private var _listenerBus: LiveListenerBus = _ private var _env: SparkEnv = _ private var _jobProgressListener: JobProgressListener = _ private var _statusTracker: SparkStatusTracker = _ @@ -247,7 +245,7 @@ class SparkContext(config: SparkConf) extends Logging { def isStopped: Boolean = stopped.get() // An asynchronous listener bus for Spark events - private[spark] val listenerBus = new LiveListenerBus(this) + private[spark] def listenerBus: LiveListenerBus = _listenerBus // This function allows components created by SparkEnv to be mocked in unit tests: private[spark] def createSparkEnv( @@ -348,13 +346,6 @@ class SparkContext(config: SparkConf) extends Logging { value } - private def warnDeprecatedVersions(): Unit = { - val javaVersion = System.getProperty("java.version").split("[+.\\-]+", 3) - if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.10"))) { - logWarning("Support for Scala 2.10 is deprecated as of Spark 2.1.0") - } - } - /** Control our logLevel. This overrides any user-defined log settings. * @param logLevel The desired log level as a string. * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN @@ -423,10 +414,12 @@ class SparkContext(config: SparkConf) extends Logging { if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") + _listenerBus = new LiveListenerBus(_conf) + // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. 
_jobProgressListener = new JobProgressListener(_conf) - listenerBus.addListener(jobProgressListener) + listenerBus.addToStatusQueue(jobProgressListener) // Create the Spark execution environment (cache, map output tracker, etc) _env = createSparkEnv(_conf, isLocal, listenerBus) @@ -449,7 +442,7 @@ class SparkContext(config: SparkConf) extends Logging { _ui = if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, + Some(SparkUI.createLiveUI(this, _conf, _jobProgressListener, _env.securityManager, appName, startTime = startTime)) } else { // For tests, do not enable the UI @@ -529,7 +522,7 @@ class SparkContext(config: SparkConf) extends Logging { new EventLoggingListener(_applicationId, _applicationAttemptId, _eventLogDir.get, _conf, _hadoopConfiguration) logger.start() - listenerBus.addListener(logger) + listenerBus.addToEventLogQueue(logger) Some(logger) } else { None @@ -1393,6 +1386,8 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { + // TODO the context bound (<%) above should be replaced with simple type bound and implicit + // conversion but is a breaking change. This should be fixed in Spark 3.x. val param = new GrowableAccumulableParam[R, T] val acc = new Accumulable(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) @@ -1495,6 +1490,8 @@ class SparkContext(config: SparkConf) extends Logging { /** * Add a file to be downloaded with this Spark job on every node. * + * If a file is added during execution, it will not be available until the next TaskSet starts. + * * @param path can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. @@ -1511,6 +1508,8 @@ class SparkContext(config: SparkConf) extends Logging { /** * Add a file to be downloaded with this Spark job on every node. * + * If a file is added during execution, it will not be available until the next TaskSet starts. + * * @param path can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. @@ -1564,7 +1563,7 @@ class SparkContext(config: SparkConf) extends Logging { */ @DeveloperApi def addSparkListener(listener: SparkListenerInterface) { - listenerBus.addListener(listener) + listenerBus.addToSharedQueue(listener) } /** @@ -1734,6 +1733,7 @@ class SparkContext(config: SparkConf) extends Logging { * Return information about blocks stored in all of the slaves */ @DeveloperApi + @deprecated("This method may change or be removed in a future release.", "2.2.0") def getExecutorStorageStatus: Array[StorageStatus] = { assertNotStopped() env.blockManager.master.getStorageStatus @@ -1796,44 +1796,46 @@ class SparkContext(config: SparkConf) extends Logging { /** * Adds a JAR dependency for all tasks to be executed on this `SparkContext` in the future. + * + * If a jar is added during execution, it will not be available until the next TaskSet starts. + * * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. 
*/ def addJar(path: String) { + def addJarFile(file: File): String = { + try { + if (!file.exists()) { + throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found") + } + if (file.isDirectory) { + throw new IllegalArgumentException( + s"Directory ${file.getAbsoluteFile} is not allowed for addJar") + } + env.rpcEnv.fileServer.addJar(file) + } catch { + case NonFatal(e) => + logError(s"Failed to add $path to Spark environment", e) + null + } + } + if (path == null) { logWarning("null specified as parameter to addJar") } else { - var key = "" - if (path.contains("\\")) { + val key = if (path.contains("\\")) { // For local paths with backslashes on Windows, URI throws an exception - key = env.rpcEnv.fileServer.addJar(new File(path)) + addJarFile(new File(path)) } else { val uri = new URI(path) // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies Utils.validateURL(uri) - key = uri.getScheme match { + uri.getScheme match { // A JAR file which exists only on the driver node - case null | "file" => - try { - val file = new File(uri.getPath) - if (!file.exists()) { - throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found") - } - if (file.isDirectory) { - throw new IllegalArgumentException( - s"Directory ${file.getAbsoluteFile} is not allowed for addJar") - } - env.rpcEnv.fileServer.addJar(new File(uri.getPath)) - } catch { - case NonFatal(e) => - logError(s"Failed to add $path to Spark environment", e) - null - } + case null | "file" => addJarFile(new File(uri.getPath)) // A JAR file which exists locally on every worker node - case "local" => - "file:" + uri.getPath - case _ => - path + case "local" => "file:" + uri.getPath + case _ => path } } if (key != null) { @@ -1877,8 +1879,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def stop(): Unit = { if (LiveListenerBus.withinListenerThread.value) { - throw new SparkException( - s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") + throw new SparkException(s"Cannot stop SparkContext within listener bus thread.") } // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. @@ -1938,6 +1939,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this + // `SparkContext` is stopped. + localProperties.remove() // Unset YARN mode system env variable, to allow switching between cluster types. 
System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() @@ -2373,7 +2377,7 @@ class SparkContext(config: SparkConf) extends Logging { " parameter from breaking Spark's ability to find a valid constructor.") } } - listenerBus.addListener(listener) + listenerBus.addToSharedQueue(listener) logInfo(s"Registered listener $className") } } catch { @@ -2385,7 +2389,7 @@ class SparkContext(config: SparkConf) extends Logging { } } - listenerBus.start() + listenerBus.start(this, _env.metricsSystem) _listenerBusStarted = true } @@ -2599,9 +2603,9 @@ object SparkContext extends Logging { */ private[spark] val LEGACY_DRIVER_IDENTIFIER = "" - private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]) + private implicit def arrayToArrayWritable[T <: Writable : ClassTag](arr: Traversable[T]) : ArrayWritable = { - def anyToWritable[U <% Writable](u: U): Writable = u + def anyToWritable[U <: Writable](u: U): Writable = u new ArrayWritable(classTag[T].runtimeClass.asInstanceOf[Class[Writable]], arr.map(x => anyToWritable(x)).toArray) @@ -2822,6 +2826,42 @@ object WritableConverter { // them automatically. However, we still keep the old functions in SparkContext for backward // compatibility and forward to the following functions directly. + // The following implicit declarations have been added on top of the very similar ones + // below in order to enable compatibility with Scala 2.12. Scala 2.12 deprecates eta + // expansion of zero-arg methods and thus won't match a no-arg method where it expects + // an implicit that is a function of no args. + + implicit val intWritableConverterFn: () => WritableConverter[Int] = + () => simpleWritableConverter[Int, IntWritable](_.get) + + implicit val longWritableConverterFn: () => WritableConverter[Long] = + () => simpleWritableConverter[Long, LongWritable](_.get) + + implicit val doubleWritableConverterFn: () => WritableConverter[Double] = + () => simpleWritableConverter[Double, DoubleWritable](_.get) + + implicit val floatWritableConverterFn: () => WritableConverter[Float] = + () => simpleWritableConverter[Float, FloatWritable](_.get) + + implicit val booleanWritableConverterFn: () => WritableConverter[Boolean] = + () => simpleWritableConverter[Boolean, BooleanWritable](_.get) + + implicit val bytesWritableConverterFn: () => WritableConverter[Array[Byte]] = { + () => simpleWritableConverter[Array[Byte], BytesWritable] { bw => + // getBytes method returns array which is longer then data to be returned + Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) + } + } + + implicit val stringWritableConverterFn: () => WritableConverter[String] = + () => simpleWritableConverter[String, Text](_.toString) + + implicit def writableWritableConverterFn[T <: Writable : ClassTag]: () => WritableConverter[T] = + () => new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) + + // These implicits remain included for backwards-compatibility. They fulfill the + // same role as those above. 
+ implicit def intWritableConverter(): WritableConverter[Int] = simpleWritableConverter[Int, IntWritable](_.get) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index f4a59f069a5f..24928150315e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -177,7 +177,7 @@ object SparkEnv extends Logging { SparkContext.DRIVER_IDENTIFIER, bindAddress, advertiseAddress, - port, + Option(port), isLocal, numCores, ioEncryptionKey, @@ -194,7 +194,6 @@ object SparkEnv extends Logging { conf: SparkConf, executorId: String, hostname: String, - port: Int, numCores: Int, ioEncryptionKey: Option[Array[Byte]], isLocal: Boolean): SparkEnv = { @@ -203,7 +202,7 @@ object SparkEnv extends Logging { executorId, hostname, hostname, - port, + None, isLocal, numCores, ioEncryptionKey @@ -220,7 +219,7 @@ object SparkEnv extends Logging { executorId: String, bindAddress: String, advertiseAddress: String, - port: Int, + port: Option[Int], isLocal: Boolean, numUsableCores: Int, ioEncryptionKey: Option[Array[Byte]], @@ -243,17 +242,12 @@ object SparkEnv extends Logging { } val systemName = if (isDriver) driverSystemName else executorSystemName - val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf, - securityManager, clientMode = !isDriver) + val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf, + securityManager, numUsableCores, !isDriver) // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. - // In the non-driver case, the RPC env's address may be null since it may not be listening - // for incoming connections. if (isDriver) { conf.set("spark.driver.port", rpcEnv.address.port.toString) - } else if (rpcEnv.address != null) { - conf.set("spark.executor.port", rpcEnv.address.port.toString) - logInfo(s"Setting spark.executor.port to: ${rpcEnv.address.port.toString}") } // Create an instance of the class with the given name, possibly initializing it with our conf @@ -426,7 +420,7 @@ object SparkEnv extends Logging { if (!conf.contains("spark.scheduler.mode")) { Seq(("spark.scheduler.mode", schedulingMode)) } else { - Seq[(String, String)]() + Seq.empty[(String, String)] } val sparkProperties = (conf.getAll ++ schedulerMode).sorted diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 8cd1d1c96aa0..01d8973e1bb0 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -110,10 +110,10 @@ private[spark] class TaskContextImpl( /** Marks the task as completed and triggers the completion listeners. 
*/ @GuardedBy("this") - private[spark] def markTaskCompleted(): Unit = synchronized { + private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { if (completed) return completed = true - invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) { + invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) { _.onTaskCompletion(this) } } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 3f912dc19151..a80016dd22fc 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -58,8 +58,8 @@ private[spark] object TestUtils { def createJarWithClasses( classNames: Seq[String], toStringValue: String = "", - classNamesWithBase: Seq[(String, String)] = Seq(), - classpathUrls: Seq[URL] = Seq()): URL = { + classNamesWithBase: Seq[(String, String)] = Seq.empty, + classpathUrls: Seq[URL] = Seq.empty): URL = { val tempDir = Utils.createTempDir() val files1 = for (name <- classNames) yield { createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) @@ -137,7 +137,7 @@ private[spark] object TestUtils { val options = if (classpathUrls.nonEmpty) { Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) } else { - Seq() + Seq.empty } compiler.getTask(null, null, null, options.asJava, null, Arrays.asList(sourceFile)).call() @@ -160,7 +160,7 @@ private[spark] object TestUtils { destDir: File, toStringValue: String = "", baseClass: String = null, - classpathUrls: Seq[URL] = Seq()): File = { + classpathUrls: Seq[URL] = Seq.empty): File = { val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") val sourceFile = new JavaSourceFromString(className, "public class " + className + extendsText + " implements java.io.Serializable {" + diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index b71af0d42cdb..b6df5663d919 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -22,8 +22,8 @@ import java.lang.{Double => JDouble} import scala.language.implicitConversions import scala.reflect.ClassTag -import org.apache.spark.annotation.Since import org.apache.spark.Partitioner +import org.apache.spark.annotation.Since import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 9481156bc93a..f1936bf58728 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -757,6 +757,12 @@ class JavaSparkContext(val sc: SparkContext) */ def getLocalProperty(key: String): String = sc.getLocalProperty(key) + /** + * Set a human readable description of the current job. + * @since 2.3.0 + */ + def setJobDescription(value: String): Unit = sc.setJobDescription(value) + /** Control our logLevel. This overrides any user-defined log settings. * @param logLevel The desired log level as a string. 
* Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index f820401da2fc..d6506231b8d7 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -56,9 +56,9 @@ private[spark] object JavaUtils { val ui = underlying.iterator var prev : Option[A] = None - def hasNext: Boolean = ui.hasNext + override def hasNext: Boolean = ui.hasNext - def next(): Entry[A, B] = { + override def next(): Entry[A, B] = { val (k, v) = ui.next() prev = Some(k) new ju.Map.Entry[A, B] { @@ -74,7 +74,7 @@ private[spark] object JavaUtils { } } - def remove() { + override def remove() { prev match { case Some(k) => underlying match { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b0dd2fc187ba..f6293c0dc509 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -48,7 +48,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) override def getPartitions: Array[Partition] = firstParent.partitions @@ -59,7 +59,7 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = PythonRunner(func, bufferSize, reuse_worker) + val runner = PythonRunner(func, bufferSize, reuseWorker) runner.compute(firstParent.iterator(split, context), split.index, context) } } @@ -83,306 +83,9 @@ private[spark] case class PythonFunction( */ private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction]) -private[spark] object PythonRunner { - def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = { - new PythonRunner( - Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0))) - } -} - -/** - * A helper class to run Python mapPartition/UDFs in Spark. - * - * funcs is a list of independent Python functions, each one of them is a list of chained Python - * functions (from bottom to top). - */ -private[spark] class PythonRunner( - funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuse_worker: Boolean, - isUDF: Boolean, - argOffsets: Array[Array[Int]]) - extends Logging { - - require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") - - // All the Python functions should have the same exec, version and envvars. 
- private val envVars = funcs.head.funcs.head.envVars - private val pythonExec = funcs.head.funcs.head.pythonExec - private val pythonVer = funcs.head.funcs.head.pythonVer - - // TODO: support accumulator in multiple UDF - private val accumulator = funcs.head.funcs.head.accumulator - - def compute( - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext): Iterator[Array[Byte]] = { - val startTime = System.currentTimeMillis - val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") - envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread - if (reuse_worker) { - envVars.put("SPARK_REUSE_WORKER", "1") - } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) - // Whether is the worker released into idle pool - @volatile var released = false - - // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) - - context.addTaskCompletionListener { context => - writerThread.shutdownOnTaskCompletion() - if (!reuse_worker || !released) { - try { - worker.close() - } catch { - case e: Exception => - logWarning("Failed to close worker socket", e) - } - } - } - - writerThread.start() - new MonitorThread(env, worker, context).start() - - // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - val stdoutIterator = new Iterator[Array[Byte]] { - override def next(): Array[Byte] = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj - } - - private def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - case 0 => Array.empty[Byte] - case SpecialLengths.TIMING_DATA => - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - initTime - val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) - val memoryBytesSpilled = stream.readLong() - val diskBytesSpilled = stream.readLong() - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - throw new PythonException(new String(obj, StandardCharsets.UTF_8), - writerThread.exception.getOrElse(null)) - case SpecialLengths.END_OF_DATA_SECTION => - // We've finished the data section of the output, but we can still - // read some accumulator updates: - val numAccumulatorUpdates = stream.readInt() - (1 to numAccumulatorUpdates).foreach { _ => - val updateLen = stream.readInt() - val update = new Array[Byte](updateLen) - stream.readFully(update) - accumulator.add(update) - } - // Check whether the worker is ready to be re-used. 
- if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) - released = true - } - } - null - } - } catch { - - case e: Exception if context.isInterrupted => - logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - - case e: Exception if env.isStopped => - logDebug("Exception thrown after context is stopped", e) - null // exit silently - - case e: Exception if writerThread.exception.isDefined => - logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get - - case eof: EOFException => - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } - } - - var _nextObj = read() - - override def hasNext: Boolean = _nextObj != null - } - new InterruptibleIterator(context, stdoutIterator) - } - - /** - * The thread responsible for writing the data from the PythonRDD's parent iterator to the - * Python process. - */ - class WriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext) - extends Thread(s"stdout writer for $pythonExec") { - - @volatile private var _exception: Exception = null - - private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet - private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) - - setDaemon(true) - - /** Contains the exception thrown while writing the parent iterator to the Python process. */ - def exception: Option[Exception] = Option(_exception) - - /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. 
*/ - def shutdownOnTaskCompletion() { - assert(context.isCompleted) - this.interrupt() - } - - override def run(): Unit = Utils.logUncaughtExceptions { - try { - TaskContext.setTaskContext(context) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - // Partition index - dataOut.writeInt(partitionIndex) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) - // Write out the TaskContextInfo - dataOut.writeInt(context.stageId()) - dataOut.writeInt(context.partitionId()) - dataOut.writeInt(context.attemptNumber()) - dataOut.writeLong(context.taskAttemptId()) - // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - // Broadcast variables - val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet - // number of different broadcasts - val toRemove = oldBids.diff(newBids) - val cnt = toRemove.size + newBids.diff(oldBids).size - dataOut.writeInt(cnt) - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) - } - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { - // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) - } - } - dataOut.flush() - // Serialized command: - if (isUDF) { - dataOut.writeInt(1) - dataOut.writeInt(funcs.length) - funcs.zip(argOffsets).foreach { case (chained, offsets) => - dataOut.writeInt(offsets.length) - offsets.foreach { offset => - dataOut.writeInt(offset) - } - dataOut.writeInt(chained.funcs.length) - chained.funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command) - } - } - } else { - dataOut.writeInt(0) - val command = funcs.head.funcs.head.command - dataOut.writeInt(command.length) - dataOut.write(command) - } - // Data values - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) - dataOut.writeInt(SpecialLengths.END_OF_STREAM) - dataOut.flush() - } catch { - case e: Exception if context.isCompleted || context.isInterrupted => - logDebug("Exception thrown after task completion (likely due to cleanup)", e) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - - case e: Exception => - // We must avoid throwing exceptions here, because the thread uncaught exception handler - // will kill the whole executor (see org.apache.spark.executor.Executor). - _exception = e - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } - } - } - - /** - * It is necessary to have a monitor thread for python workers if the user cancels with - * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the - * threads can block indefinitely. - */ - class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) - extends Thread(s"Worker Monitor for $pythonExec") { - - setDaemon(true) - - override def run() { - // Kill the worker if it is interrupted, checking until task completion. - // TODO: This has a race condition if interruption occurs, as completed may still become true. 
- while (!context.isInterrupted && !context.isCompleted) { - Thread.sleep(2000) - } - if (!context.isCompleted) { - try { - logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - } -} - /** Thrown for exceptions in user Python code. */ -private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause) +private[spark] class PythonException(msg: String, cause: Exception) + extends RuntimeException(msg, cause) /** * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. @@ -399,14 +102,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte] val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -private object SpecialLengths { - val END_OF_DATA_SECTION = -1 - val PYTHON_EXCEPTION_THROWN = -2 - val TIMING_DATA = -3 - val END_OF_STREAM = -4 - val NULL = -5 -} - private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker @@ -683,7 +378,7 @@ private[spark] object PythonRDD extends Logging { * Create a socket server and a background thread to serve the data in `items`, * * The socket server can only accept one connection, or close if no connection - * in 3 seconds. + * in 15 seconds. * * Once a connection comes in, it tries to serialize all the data in `items` * and send them into this connection. @@ -692,8 +387,8 @@ private[spark] object PythonRDD extends Logging { */ def serveIterator[T](items: Iterator[T], threadName: String): Int = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) - // Close the socket if no connection in 3 seconds - serverSocket.setSoTimeout(3000) + // Close the socket if no connection in 15 seconds + serverSocket.setSoTimeout(15000) new Thread(threadName) { setDaemon(true) @@ -879,7 +574,7 @@ private[spark] class PythonAccumulatorV2( private val serverPort: Int) extends CollectionAccumulator[Array[Byte]] { - Utils.checkHost(serverHost, "Expected hostname") + Utils.checkHost(serverHost) val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) @@ -974,6 +669,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial } } } + super.finalize() } } // scalastyle:on no.finalize diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala new file mode 100644 index 000000000000..3688a149443c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -0,0 +1,441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.python + +import java.io._ +import java.net._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.util._ + + +/** + * Enumerate the type of command that will be sent to the Python worker + */ +private[spark] object PythonEvalType { + val NON_UDF = 0 + val SQL_BATCHED_UDF = 1 + val SQL_PANDAS_UDF = 2 +} + +/** + * A helper class to run Python mapPartition/UDFs in Spark. + * + * funcs is a list of independent Python functions, each one of them is a list of chained Python + * functions (from bottom to top). + */ +private[spark] abstract class BasePythonRunner[IN, OUT]( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]]) + extends Logging { + + require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + + // All the Python functions should have the same exec, version and envvars. + protected val envVars = funcs.head.funcs.head.envVars + protected val pythonExec = funcs.head.funcs.head.pythonExec + protected val pythonVer = funcs.head.funcs.head.pythonVer + + // TODO: support accumulator in multiple UDF + protected val accumulator = funcs.head.funcs.head.accumulator + + def compute( + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): Iterator[OUT] = { + val startTime = System.currentTimeMillis + val env = SparkEnv.get + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread + if (reuseWorker) { + envVars.put("SPARK_REUSE_WORKER", "1") + } + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) + // Whether is the worker released into idle pool + val released = new AtomicBoolean(false) + + // Start a thread to feed the process input from our parent's iterator + val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + + context.addTaskCompletionListener { _ => + writerThread.shutdownOnTaskCompletion() + if (!reuseWorker || !released.get) { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + writerThread.start() + new MonitorThread(env, worker, context).start() + + // Return an iterator that read lines from the process's stdout + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + val stdoutIterator = newReaderIterator( + stream, writerThread, startTime, env, worker, released, context) + new InterruptibleIterator(context, stdoutIterator) + } + + protected def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): WriterThread + + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[OUT] + + /** + * The thread responsible for writing the data from the PythonRDD's parent iterator to the + * Python process. 
+ */ + abstract class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext) + extends Thread(s"stdout writer for $pythonExec") { + + @volatile private var _exception: Exception = null + + private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet + private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) + + setDaemon(true) + + /** Contains the exception thrown while writing the parent iterator to the Python process. */ + def exception: Option[Exception] = Option(_exception) + + /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ + def shutdownOnTaskCompletion() { + assert(context.isCompleted) + this.interrupt() + } + + /** + * Writes a command section to the stream connected to the Python worker. + */ + protected def writeCommand(dataOut: DataOutputStream): Unit + + /** + * Writes input data to the stream connected to the Python worker. + */ + protected def writeIteratorToStream(dataOut: DataOutputStream): Unit + + override def run(): Unit = Utils.logUncaughtExceptions { + try { + TaskContext.setTaskContext(context) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + // Partition index + dataOut.writeInt(partitionIndex) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) + // Write out the TaskContextInfo + dataOut.writeInt(context.stageId()) + dataOut.writeInt(context.partitionId()) + dataOut.writeInt(context.attemptNumber()) + dataOut.writeLong(context.taskAttemptId()) + // sparkFilesDir + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.size) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } + // Broadcast variables + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val toRemove = oldBids.diff(newBids) + val cnt = toRemove.size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } + } + dataOut.flush() + + dataOut.writeInt(evalType) + writeCommand(dataOut) + writeIteratorToStream(dataOut) + + dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + } catch { + case e: Exception if context.isCompleted || context.isInterrupted => + logDebug("Exception thrown after task completion (likely due to cleanup)", e) + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + + case e: Exception => + // We must avoid throwing exceptions here, because the thread uncaught exception handler + // will kill the whole executor (see org.apache.spark.executor.Executor). 
+ _exception = e + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } + } + } + + abstract class ReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext) + extends Iterator[OUT] { + + private var nextObj: OUT = _ + private var eos = false + + override def hasNext: Boolean = nextObj != null || { + if (!eos) { + nextObj = read() + hasNext + } else { + false + } + } + + override def next(): OUT = { + if (hasNext) { + val obj = nextObj + nextObj = null.asInstanceOf[OUT] + obj + } else { + Iterator.empty.next() + } + } + + /** + * Reads next object from the stream. + * When the stream reaches end of data, needs to process the following sections, + * and then returns null. + */ + protected def read(): OUT + + protected def handleTimingData(): Unit = { + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, + init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + } + + protected def handlePythonException(): PythonException = { + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + new PythonException(new String(obj, StandardCharsets.UTF_8), + writerThread.exception.getOrElse(null)) + } + + protected def handleEndOfDataSection(): Unit = { + // We've finished the data section of the output, but we can still + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) + stream.readFully(update) + accumulator.add(update) + } + // Check whether the worker is ready to be re-used. + if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + if (reuseWorker) { + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + released.set(true) + } + } + eos = true + } + + protected val handleException: PartialFunction[Throwable, OUT] = { + case e: Exception if context.isInterrupted => + logDebug("Exception thrown after task interruption", e) + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) + + case e: Exception if env.isStopped => + logDebug("Exception thrown after context is stopped", e) + null.asInstanceOf[OUT] // exit silently + + case e: Exception if writerThread.exception.isDefined => + logError("Python worker exited unexpectedly (crashed)", e) + logError("This may have been caused by a prior exception:", writerThread.exception.get) + throw writerThread.exception.get + + case eof: EOFException => + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) + } + } + + /** + * It is necessary to have a monitor thread for python workers if the user cancels with + * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the + * threads can block indefinitely. 
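
To make the framing above concrete: the writer emits length-prefixed records and closes sections with the negative sentinels defined in SpecialLengths, and the reader dispatches on that same leading integer. Below is a minimal, self-contained sketch of that convention only (plain java.io, not Spark code; the record contents are invented):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object FramingSketch {
  // Sentinels mirroring the SpecialLengths values used in the patch.
  val END_OF_DATA_SECTION = -1
  val END_OF_STREAM = -4

  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    // Each record is written as <length><payload>; a section ends with a negative sentinel.
    Seq("a", "bb", "ccc").foreach { s =>
      val payload = s.getBytes("UTF-8")
      out.writeInt(payload.length)
      out.write(payload)
    }
    out.writeInt(END_OF_DATA_SECTION)
    out.writeInt(END_OF_STREAM)
    out.flush()

    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    var done = false
    while (!done) {
      in.readInt() match {
        case len if len > 0 =>
          val buf = new Array[Byte](len)
          in.readFully(buf)
          println(new String(buf, "UTF-8"))
        case END_OF_DATA_SECTION => println("end of data section")
        case END_OF_STREAM => done = true
      }
    }
  }
}
```

The real protocol additionally carries timing data, accumulator updates and Python exceptions behind their own sentinels, which is what the handle* methods above are for.
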
+ */ + class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) + extends Thread(s"Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + // Kill the worker if it is interrupted, checking until task completion. + // TODO: This has a race condition if interruption occurs, as completed may still become true. + while (!context.isInterrupted && !context.isCompleted) { + Thread.sleep(2000) + } + if (!context.isCompleted) { + try { + logWarning("Incomplete task interrupted: Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } + } + } +} + +private[spark] object PythonRunner { + + def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonRunner = { + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker) + } +} + +/** + * A helper class to run Python mapPartition in Spark. + */ +private[spark] class PythonRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean) + extends BasePythonRunner[Array[Byte], Array[Byte]]( + funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + val command = funcs.head.funcs.head.command + dataOut.writeInt(command.length) + dataOut.write(command) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } +} + +private[spark] object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 + val END_OF_STREAM = -4 + val NULL = -5 + val START_ARROW_STREAM = -6 +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index c4e55b5e8902..92e228a9dd10 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", 
"pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.4-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 6a5e6f7c5afb..fc595ae9e456 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -22,8 +22,8 @@ import java.net.{InetAddress, ServerSocket, Socket, SocketException} import java.nio.charset.StandardCharsets import java.util.Arrays -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark._ import org.apache.spark.internal.Logging diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 6e4eab4b805c..01e64b6972ae 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -35,6 +35,16 @@ import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ private[spark] object SerDeUtil extends Logging { + class ByteArrayConstructor extends net.razorvine.pickle.objects.ByteArrayConstructor { + override def construct(args: Array[Object]): Object = { + // Deal with an empty byte array pickled by Python 3. + if (args.length == 0) { + Array.emptyByteArray + } else { + super.construct(args) + } + } + } // Unpickle array.array generated by Python 2.6 class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor { // /* Description of types */ @@ -55,13 +65,12 @@ private[spark] object SerDeUtil extends Logging { // {'d', sizeof(double), d_getitem, d_setitem}, // {'\0', 0, 0, 0} /* Sentinel */ // }; - // TODO: support Py_UNICODE with 2 bytes val machineCodes: Map[Char, Int] = if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { - Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 3, 'h' -> 5, 'I' -> 7, 'i' -> 9, + Map('B' -> 0, 'b' -> 1, 'H' -> 3, 'h' -> 5, 'I' -> 7, 'i' -> 9, 'L' -> 11, 'l' -> 13, 'f' -> 15, 'd' -> 17, 'u' -> 21 ) } else { - Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 2, 'h' -> 4, 'I' -> 6, 'i' -> 8, + Map('B' -> 0, 'b' -> 1, 'H' -> 2, 'h' -> 4, 'I' -> 6, 'i' -> 8, 'L' -> 10, 'l' -> 12, 'f' -> 14, 'd' -> 16, 'u' -> 20 ) } @@ -72,7 +81,30 @@ private[spark] object SerDeUtil extends Logging { val typecode = args(0).asInstanceOf[String].charAt(0) // This must be ISO 8859-1 / Latin 1, not UTF-8, to interoperate correctly val data = args(1).asInstanceOf[String].getBytes(StandardCharsets.ISO_8859_1) - construct(typecode, machineCodes(typecode), data) + if (typecode == 'c') { + // It seems like the pickle of pypy uses the similar protocol to Python 2.6, which uses + // a string for array data instead of list as Python 2.7, and handles an array of + // typecode 'c' as 1-byte character. 
+ val result = new Array[Char](data.length) + var i = 0 + while (i < data.length) { + result(i) = data(i).toChar + i += 1 + } + result + } else { + construct(typecode, machineCodes(typecode), data) + } + } else if (args.length == 2 && args(0) == "l") { + // On Python 2, an array of typecode 'l' should be handled as long rather than int. + val values = args(1).asInstanceOf[JArrayList[_]] + val result = new Array[Long](values.size) + var i = 0 + while (i < values.size) { + result(i) = values.get(i).asInstanceOf[Number].longValue() + i += 1 + } + result } else { super.construct(args) } @@ -86,6 +118,10 @@ private[spark] object SerDeUtil extends Logging { synchronized{ if (!initialized) { Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + Unpickler.registerConstructor("__builtin__", "bytearray", new ByteArrayConstructor()) + Unpickler.registerConstructor("builtins", "bytearray", new ByteArrayConstructor()) + Unpickler.registerConstructor("__builtin__", "bytes", new ByteArrayConstructor()) + Unpickler.registerConstructor("_codecs", "encode", new ByteArrayConstructor()) initialized = true } } diff --git a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala index 3432700f1160..b8c4ff9d477a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala @@ -17,8 +17,8 @@ package org.apache.spark.api.r -import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger /** JVM object ID wrapper */ private[r] case class JVMObjectId(id: String) { @@ -37,13 +37,7 @@ private[r] class JVMObjectTracker { /** * Returns the JVM object associated with the input key or None if not found. */ - final def get(id: JVMObjectId): Option[Object] = this.synchronized { - if (objMap.containsKey(id)) { - Some(objMap.get(id)) - } else { - None - } - } + final def get(id: JVMObjectId): Option[Object] = Option(objMap.get(id)) /** * Returns the JVM object associated with the input key or throws an exception if not found. @@ -67,13 +61,7 @@ private[r] class JVMObjectTracker { /** * Removes and returns a JVM object with the specific ID from the tracker, or None if not found. */ - final def remove(id: JVMObjectId): Option[Object] = this.synchronized { - if (objMap.containsKey(id)) { - Some(objMap.remove(id)) - } else { - None - } - } + final def remove(id: JVMObjectId): Option[Object] = Option(objMap.remove(id)) /** * Number of JVM objects being tracked. 
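
The JVMObjectTracker simplification above leans on a property of ConcurrentHashMap: it never stores null values, so wrapping get/remove in Option is equivalent to the old containsKey-then-get sequence and avoids its check-then-act race. A small stand-alone illustration (the map contents are invented):

```scala
import java.util.concurrent.ConcurrentHashMap

object TrackerSketch {
  // ConcurrentHashMap never holds null values, so Option(map.get(k)) is None exactly
  // when the key is absent, with a single atomic lookup.
  private val objMap = new ConcurrentHashMap[String, Object]()

  def get(id: String): Option[Object] = Option(objMap.get(id))
  def remove(id: String): Option[Object] = Option(objMap.remove(id))

  def main(args: Array[String]): Unit = {
    objMap.put("df1", "some JVM object")
    println(get("df1"))    // Some(some JVM object)
    println(remove("df1")) // Some(some JVM object)
    println(get("df1"))    // None
  }
}
```
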
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index cfd37ac54ba2..18fc595301f4 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -26,9 +26,9 @@ import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.channel.ChannelHandler.Sharable import io.netty.handler.timeout.ReadTimeoutException +import org.apache.spark.SparkConf import org.apache.spark.api.r.SerDe._ import org.apache.spark.internal.Logging -import org.apache.spark.SparkConf import org.apache.spark.util.{ThreadUtils, Utils} /** diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index dad928cdcfd0..537ab57f9664 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -128,8 +128,7 @@ private[spark] object SerDe { } def readBoolean(in: DataInputStream): Boolean = { - val intVal = in.readInt() - if (intVal == 0) false else true + in.readInt() != 0 } def readDate(in: DataInputStream): Date = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 039df75ce74f..67e993c7f02e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ -import org.apache.spark.util.{ByteBufferInputStream, Utils} +import org.apache.spark.util.Utils import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} /** diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala new file mode 100644 index 000000000000..ecc82d7ac800 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.io.File + +import org.apache.commons.lang3.StringUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.util.{MutableURLClassLoader, Utils} + +private[deploy] object DependencyUtils { + + def resolveMavenDependencies( + packagesExclusions: String, + packages: String, + repositories: String, + ivyRepoPath: String): String = { + val exclusions: Seq[String] = + if (!StringUtils.isBlank(packagesExclusions)) { + packagesExclusions.split(",") + } else { + Nil + } + // Create the IvySettings, either load from file or build defaults + val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile => + SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath)) + }.getOrElse { + SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) + } + + SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions) + } + + def resolveAndDownloadJars( + jars: String, + userJar: String, + sparkConf: SparkConf, + hadoopConf: Configuration, + secMgr: SecurityManager): String = { + val targetDir = Utils.createTempDir() + Option(jars) + .map { + resolveGlobPaths(_, hadoopConf) + .split(",") + .filterNot(_.contains(userJar.split("/").last)) + .mkString(",") + } + .filterNot(_ == "") + .map(downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr)) + .orNull + } + + def addJarsToClassPath(jars: String, loader: MutableURLClassLoader): Unit = { + if (jars != null) { + for (jar <- jars.split(",")) { + SparkSubmit.addJarToClasspath(jar, loader) + } + } + } + + /** + * Download a list of remote files to temp local files. If the file is local, the original file + * will be returned. + * + * @param fileList A comma separated file list. + * @param targetDir A temporary directory for which downloaded files. + * @param sparkConf Spark configuration. + * @param hadoopConf Hadoop configuration. + * @param secMgr Spark security manager. + * @return A comma separated local files list. + */ + def downloadFileList( + fileList: String, + targetDir: File, + sparkConf: SparkConf, + hadoopConf: Configuration, + secMgr: SecurityManager): String = { + require(fileList != null, "fileList cannot be null.") + Utils.stringToSeq(fileList) + .map(downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr)) + .mkString(",") + } + + /** + * Download a file from the remote to a local temporary directory. If the input path points to + * a local path, returns it with no operation. + * + * @param path A file path from where the files will be downloaded. + * @param targetDir A temporary directory for which downloaded files. + * @param sparkConf Spark configuration. + * @param hadoopConf Hadoop configuration. + * @param secMgr Spark security manager. + * @return Path to the local file. + */ + def downloadFile( + path: String, + targetDir: File, + sparkConf: SparkConf, + hadoopConf: Configuration, + secMgr: SecurityManager): String = { + require(path != null, "path cannot be null.") + val uri = Utils.resolveURI(path) + + uri.getScheme match { + case "file" | "local" => path + case "http" | "https" | "ftp" if Utils.isTesting => + // This is only used for SparkSubmitSuite unit test. Instead of downloading file remotely, + // return a dummy local path instead. 
+ val file = new File(uri.getPath) + new File(targetDir, file.getName).toURI.toString + case _ => + val fname = new Path(uri).getName() + val localFile = Utils.doFetchFile(uri.toString(), targetDir, fname, sparkConf, secMgr, + hadoopConf) + localFile.toURI().toString() + } + } + + def resolveGlobPaths(paths: String, hadoopConf: Configuration): String = { + require(paths != null, "paths cannot be null.") + Utils.stringToSeq(paths).flatMap { path => + val uri = Utils.resolveURI(path) + uri.getScheme match { + case "local" | "http" | "https" | "ftp" => Array(path) + case _ => + val fs = FileSystem.get(uri, hadoopConf) + Option(fs.globStatus(new Path(uri))).map { status => + status.filter(_.isFile).map(_.getPath.toUri.toString) + }.getOrElse(Array(path)) + } + }.mkString(",") + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index ac09c6c497f8..49a319abb323 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,7 +24,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable @@ -34,6 +34,16 @@ private[deploy] object DeployMessages { // Worker to Master + /** + * @param id the worker id + * @param host the worker host + * @param port the worker post + * @param worker the worker endpoint ref + * @param cores the core number of worker + * @param memory the memory size of worker + * @param workerWebUiUrl the worker Web UI address + * @param masterAddress the master address used by the worker to connect + */ case class RegisterWorker( id: String, host: String, @@ -41,9 +51,10 @@ private[deploy] object DeployMessages { worker: RpcEndpointRef, cores: Int, memory: Int, - workerWebUiUrl: String) + workerWebUiUrl: String, + masterAddress: RpcAddress) extends DeployMessage { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } @@ -80,8 +91,16 @@ private[deploy] object DeployMessages { sealed trait RegisterWorkerResponse - case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage - with RegisterWorkerResponse + /** + * @param master the master ref + * @param masterWebUiUrl the master Web UI address + * @param masterAddress the master address used by the worker to connect. It should be + * [[RegisterWorker.masterAddress]]. 
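
Returning briefly to the DependencyUtils.downloadFile helper added above: its core is a dispatch on the URI scheme, where file and local paths are returned untouched and anything else is fetched into a temporary directory. The sketch below shows only that shape, using java.net.URI; the placeholder string stands in for the real Utils.doFetchFile call and is not Spark code:

```scala
import java.net.URI

object SchemeDispatchSketch {
  // Illustrative only: branch on the URI scheme as downloadFile does, with the
  // remote fetch replaced by a placeholder result.
  def resolve(path: String): String = {
    val uri = new URI(path)
    Option(uri.getScheme).getOrElse("file") match {
      case "file" | "local" => path  // already local, return as-is
      case scheme => s"<would fetch $scheme resource ${uri.getPath} into a temp dir>"
    }
  }

  def main(args: Array[String]): Unit = {
    Seq("file:///tmp/app.jar", "local:/opt/libs/dep.jar", "hdfs:///user/libs/dep.jar")
      .foreach(p => println(s"$p -> ${resolve(p)}"))
  }
}
```
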
+ */ + case class RegisteredWorker( + master: RpcEndpointRef, + masterWebUiUrl: String, + masterAddress: RpcAddress) extends DeployMessage with RegisterWorkerResponse case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse @@ -131,7 +150,7 @@ private[deploy] object DeployMessages { // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { - Utils.checkHostPort(hostPort, "Required hostport") + Utils.checkHostPort(hostPort) } case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], @@ -139,6 +158,8 @@ private[deploy] object DeployMessages { case class ApplicationRemoved(message: String) + case class WorkerRemoved(id: String, host: String, message: String) + // DriverClient <-> Master case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage @@ -183,7 +204,7 @@ private[deploy] object DeployMessages { completedDrivers: Array[DriverInfo], status: MasterState) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) def uri: String = "spark://" + host + ":" + port @@ -201,7 +222,7 @@ private[deploy] object DeployMessages { drivers: List[DriverRunner], finishedDrivers: List[DriverRunner], masterUrl: String, cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 8d491ddf6e09..f975fa5cb4e2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -56,7 +56,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private var server: TransportServer = _ - private val shuffleServiceSource = new ExternalShuffleServiceSource(blockHandler) + private val shuffleServiceSource = new ExternalShuffleServiceSource /** Create a new shuffle block handler. Factored out for subclasses to override. 
*/ protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = { @@ -83,6 +83,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana } server = transportContext.createServer(port, bootstraps.asJava) + shuffleServiceSource.registerMetricSet(server.getAllMetrics) + shuffleServiceSource.registerMetricSet(blockHandler.getAllMetrics) masterMetricsSystem.registerSource(shuffleServiceSource) masterMetricsSystem.start() } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleServiceSource.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleServiceSource.scala index 357a9769311a..ccfc97439878 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleServiceSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleServiceSource.scala @@ -19,19 +19,19 @@ package org.apache.spark.deploy import javax.annotation.concurrent.ThreadSafe -import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.{MetricRegistry, MetricSet} import org.apache.spark.metrics.source.Source -import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler /** * Provides metrics source for external shuffle service */ @ThreadSafe -private class ExternalShuffleServiceSource -(blockHandler: ExternalShuffleBlockHandler) extends Source { +private class ExternalShuffleServiceSource extends Source { override val metricRegistry = new MetricRegistry() override val sourceName = "shuffleService" - metricRegistry.registerAll(blockHandler.getAllMetrics) + def registerMetricSet(metricSet: MetricSet): Unit = { + metricRegistry.registerAll(metricSet) + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 220b20bf7cbd..721269616657 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -21,30 +21,65 @@ import org.json4s.JsonAST.JObject import org.json4s.JsonDSL._ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.master._ +import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.ExecutorRunner private[deploy] object JsonProtocol { - def writeWorkerInfo(obj: WorkerInfo): JObject = { - ("id" -> obj.id) ~ - ("host" -> obj.host) ~ - ("port" -> obj.port) ~ - ("webuiaddress" -> obj.webUiAddress) ~ - ("cores" -> obj.cores) ~ - ("coresused" -> obj.coresUsed) ~ - ("coresfree" -> obj.coresFree) ~ - ("memory" -> obj.memory) ~ - ("memoryused" -> obj.memoryUsed) ~ - ("memoryfree" -> obj.memoryFree) ~ - ("state" -> obj.state.toString) ~ - ("lastheartbeat" -> obj.lastHeartbeat) - } + /** + * Export the [[WorkerInfo]] to a Json object. A [[WorkerInfo]] consists of the information of a + * worker. 
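
The writeWorkerInfo helper being documented here (and the other write* helpers that follow) builds its output with json4s' JsonDSL, where the ~ operator chains key/value pairs into a JObject. A brief sketch of that idiom, assuming the json4s-jackson artifact Spark already depends on is on the classpath; the field values are invented:

```scala
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

object JsonDslSketch {
  def main(args: Array[String]): Unit = {
    // ~ joins pairs of mixed value types into a single JObject.
    val worker =
      ("id" -> "worker-20170601-1234") ~
      ("host" -> "10.0.0.5") ~
      ("port" -> 7078) ~
      ("cores" -> 8) ~
      ("coresused" -> 2) ~
      ("state" -> "ALIVE")
    println(compact(render(worker)))
  }
}
```
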
+ * + * @return a Json object containing the following fields: + * `id` a string identifier of the worker + * `host` the host that the worker is running on + * `port` the port that the worker is bound to + * `webuiaddress` the address used in web UI + * `cores` total cores of the worker + * `coresused` allocated cores of the worker + * `coresfree` free cores of the worker + * `memory` total memory of the worker + * `memoryused` allocated memory of the worker + * `memoryfree` free memory of the worker + * `state` state of the worker, see [[WorkerState]] + * `lastheartbeat` time in milliseconds that the latest heart beat message from the + * worker is received + */ + def writeWorkerInfo(obj: WorkerInfo): JObject = { + ("id" -> obj.id) ~ + ("host" -> obj.host) ~ + ("port" -> obj.port) ~ + ("webuiaddress" -> obj.webUiAddress) ~ + ("cores" -> obj.cores) ~ + ("coresused" -> obj.coresUsed) ~ + ("coresfree" -> obj.coresFree) ~ + ("memory" -> obj.memory) ~ + ("memoryused" -> obj.memoryUsed) ~ + ("memoryfree" -> obj.memoryFree) ~ + ("state" -> obj.state.toString) ~ + ("lastheartbeat" -> obj.lastHeartbeat) + } + /** + * Export the [[ApplicationInfo]] to a Json objec. An [[ApplicationInfo]] consists of the + * information of an application. + * + * @return a Json object containing the following fields: + * `id` a string identifier of the application + * `starttime` time in milliseconds that the application starts + * `name` the description of the application + * `cores` total cores granted to the application + * `user` name of the user who submitted the application + * `memoryperslave` minimal memory in MB required to each executor + * `submitdate` time in Date that the application is submitted + * `state` state of the application, see [[ApplicationState]] + * `duration` time in milliseconds that the application has been running + */ def writeApplicationInfo(obj: ApplicationInfo): JObject = { - ("starttime" -> obj.startTime) ~ ("id" -> obj.id) ~ + ("starttime" -> obj.startTime) ~ ("name" -> obj.desc.name) ~ - ("cores" -> obj.desc.maxCores) ~ + ("cores" -> obj.coresGranted) ~ ("user" -> obj.desc.user) ~ ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ @@ -52,14 +87,36 @@ private[deploy] object JsonProtocol { ("duration" -> obj.duration) } + /** + * Export the [[ApplicationDescription]] to a Json object. An [[ApplicationDescription]] consists + * of the description of an application. + * + * @return a Json object containing the following fields: + * `name` the description of the application + * `cores` max cores that can be allocated to the application, 0 means unlimited + * `memoryperslave` minimal memory in MB required to each executor + * `user` name of the user who submitted the application + * `command` the command string used to submit the application + */ def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ - ("cores" -> obj.maxCores) ~ + ("cores" -> obj.maxCores.getOrElse(0)) ~ ("memoryperslave" -> obj.memoryPerExecutorMB) ~ ("user" -> obj.user) ~ ("command" -> obj.command.toString) } + /** + * Export the [[ExecutorRunner]] to a Json object. An [[ExecutorRunner]] consists of the + * information of an executor. 
+ * + * @return a Json object containing the following fields: + * `id` an integer identifier of the executor + * `memory` memory in MB allocated to the executor + * `appid` a string identifier of the application that the executor is working on + * `appdesc` a Json object of the [[ApplicationDescription]] of the application that the + * executor is working on + */ def writeExecutorRunner(obj: ExecutorRunner): JObject = { ("id" -> obj.execId) ~ ("memory" -> obj.memory) ~ @@ -67,18 +124,59 @@ private[deploy] object JsonProtocol { ("appdesc" -> writeApplicationDescription(obj.appDesc)) } + /** + * Export the [[DriverInfo]] to a Json object. A [[DriverInfo]] consists of the information of a + * driver. + * + * @return a Json object containing the following fields: + * `id` a string identifier of the driver + * `starttime` time in milliseconds that the driver starts + * `state` state of the driver, see [[DriverState]] + * `cores` cores allocated to the driver + * `memory` memory in MB allocated to the driver + * `submitdate` time in Date that the driver is created + * `worker` identifier of the worker that the driver is running on + * `mainclass` main class of the command string that started the driver + */ def writeDriverInfo(obj: DriverInfo): JObject = { ("id" -> obj.id) ~ ("starttime" -> obj.startTime.toString) ~ ("state" -> obj.state.toString) ~ ("cores" -> obj.desc.cores) ~ - ("memory" -> obj.desc.mem) + ("memory" -> obj.desc.mem) ~ + ("submitdate" -> obj.submitDate.toString) ~ + ("worker" -> obj.worker.map(_.id).getOrElse("None")) ~ + ("mainclass" -> obj.desc.command.arguments(2)) } + /** + * Export the [[MasterStateResponse]] to a Json object. A [[MasterStateResponse]] consists the + * information of a master node. + * + * @return a Json object containing the following fields: + * `url` the url of the master node + * `workers` a list of Json objects of [[WorkerInfo]] of the workers allocated to the + * master + * `aliveworkers` size of alive workers allocated to the master + * `cores` total cores available of the master + * `coresused` cores used by the master + * `memory` total memory available of the master + * `memoryused` memory used by the master + * `activeapps` a list of Json objects of [[ApplicationInfo]] of the active applications + * running on the master + * `completedapps` a list of Json objects of [[ApplicationInfo]] of the applications + * completed in the master + * `activedrivers` a list of Json objects of [[DriverInfo]] of the active drivers of the + * master + * `completeddrivers` a list of Json objects of [[DriverInfo]] of the completed drivers + * of the master + * `status` status of the master, see [[MasterState]] + */ def writeMasterState(obj: MasterStateResponse): JObject = { val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ + ("aliveworkers" -> aliveWorkers.length) ~ ("cores" -> aliveWorkers.map(_.cores).sum) ~ ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ ("memory" -> aliveWorkers.map(_.memory).sum) ~ @@ -86,9 +184,27 @@ private[deploy] object JsonProtocol { ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ + ("completeddrivers" -> obj.completedDrivers.toList.map(writeDriverInfo)) ~ ("status" -> obj.status.toString) } + /** + * Export the [[WorkerStateResponse]] to a Json object. 
A [[WorkerStateResponse]] consists the + * information of a worker node. + * + * @return a Json object containing the following fields: + * `id` a string identifier of the worker node + * `masterurl` url of the master node of the worker + * `masterwebuiurl` the address used in web UI of the master node of the worker + * `cores` total cores of the worker + * `coreused` used cores of the worker + * `memory` total memory of the worker + * `memoryused` used memory of the worker + * `executors` a list of Json objects of [[ExecutorRunner]] of the executors running on + * the worker + * `finishedexecutors` a list of Json objects of [[ExecutorRunner]] of the finished + * executors of the worker + */ def writeWorkerState(obj: WorkerStateResponse): JObject = { ("id" -> obj.workerId) ~ ("masterurl" -> obj.masterUrl) ~ @@ -97,7 +213,7 @@ private[deploy] object JsonProtocol { ("coresused" -> obj.coresUsed) ~ ("memory" -> obj.memory) ~ ("memoryused" -> obj.memoryUsed) ~ - ("executors" -> obj.executors.toList.map(writeExecutorRunner)) ~ - ("finishedexecutors" -> obj.finishedExecutors.toList.map(writeExecutorRunner)) + ("executors" -> obj.executors.map(writeExecutorRunner)) ~ + ("finishedexecutors" -> obj.finishedExecutors.map(writeExecutorRunner)) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index a8f732b11f6c..7aca305783a7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy import java.io.File import java.net.URI -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.spark.{SparkConf, SparkUserAppException} diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index 050778a895c0..7d356e8fc1c0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -92,6 +92,9 @@ private[deploy] object RPackageUtils extends Logging { * Exposed for testing. 
*/ private[deploy] def checkManifestForR(jar: JarFile): Boolean = { + if (jar.getManifest == null) { + return false + } val manifest = jar.getManifest.getMainAttributes manifest.getValue(hasRPackage) != null && manifest.getValue(hasRPackage).trim == "true" } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 9cc321af4bde..53775db251bc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,12 +17,15 @@ package org.apache.spark.deploy -import java.io.IOException +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException} import java.security.PrivilegedExceptionAction import java.text.DateFormat import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ +import scala.collection.immutable.Map +import scala.collection.mutable +import scala.collection.mutable.HashMap import scala.util.control.NonFatal import com.google.common.primitives.Longs @@ -73,39 +76,31 @@ class SparkHadoopUtil extends Logging { } } - /** * Appends S3-specific, spark.hadoop.*, and spark.buffer.size configurations to a Hadoop * configuration. */ def appendS3AndSparkHadoopConfigurations(conf: SparkConf, hadoopConf: Configuration): Unit = { - // Note: this null check is around more than just access to the "conf" object to maintain - // the behavior of the old implementation of this code, for backwards compatibility. - if (conf != null) { - // Explicitly check for S3 environment variables - val keyId = System.getenv("AWS_ACCESS_KEY_ID") - val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") - if (keyId != null && accessKey != null) { - hadoopConf.set("fs.s3.awsAccessKeyId", keyId) - hadoopConf.set("fs.s3n.awsAccessKeyId", keyId) - hadoopConf.set("fs.s3a.access.key", keyId) - hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey) - hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey) - hadoopConf.set("fs.s3a.secret.key", accessKey) + SparkHadoopUtil.appendS3AndSparkHadoopConfigurations(conf, hadoopConf) + } - val sessionToken = System.getenv("AWS_SESSION_TOKEN") - if (sessionToken != null) { - hadoopConf.set("fs.s3a.session.token", sessionToken) - } - } - // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - conf.getAll.foreach { case (key, value) => - if (key.startsWith("spark.hadoop.")) { - hadoopConf.set(key.substring("spark.hadoop.".length), value) - } - } - val bufferSize = conf.get("spark.buffer.size", "65536") - hadoopConf.set("io.file.buffer.size", bufferSize) + /** + * Appends spark.hadoop.* configurations from a [[SparkConf]] to a Hadoop + * configuration without the spark.hadoop. prefix. + */ + def appendSparkHadoopConfigs(conf: SparkConf, hadoopConf: Configuration): Unit = { + SparkHadoopUtil.appendSparkHadoopConfigs(conf, hadoopConf) + } + + /** + * Appends spark.hadoop.* configurations from a Map to another without the spark.hadoop. prefix. + */ + def appendSparkHadoopConfigs( + srcMap: Map[String, String], + destMap: HashMap[String, String]): Unit = { + // Copy any "spark.hadoop.foo=bar" system properties into destMap as "foo=bar" + for ((key, value) <- srcMap if key.startsWith("spark.hadoop.")) { + destMap.put(key.substring("spark.hadoop.".length), value) } } @@ -114,9 +109,7 @@ class SparkHadoopUtil extends Logging { * subsystems. 
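
The new appendSparkHadoopConfigs overloads above both perform the same small transformation: copy every spark.hadoop.-prefixed entry into a Hadoop-style map with the prefix stripped. A stand-alone sketch with plain Scala maps (the property names are examples only):

```scala
object HadoopConfSketch {
  // Copy "spark.hadoop.foo=bar" entries into a plain "foo=bar" map, dropping the prefix.
  def stripSparkHadoopPrefix(src: Map[String, String]): Map[String, String] =
    src.collect {
      case (k, v) if k.startsWith("spark.hadoop.") => k.stripPrefix("spark.hadoop.") -> v
    }

  def main(args: Array[String]): Unit = {
    val sparkProps = Map(
      "spark.hadoop.fs.s3a.connection.maximum" -> "100",
      "spark.buffer.size" -> "65536")
    // Only the prefixed entry survives, with the prefix removed.
    println(stripSparkHadoopPrefix(sparkProps))
  }
}
```
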
*/ def newConfiguration(conf: SparkConf): Configuration = { - val hadoopConf = new Configuration() - appendS3AndSparkHadoopConfigurations(conf, hadoopConf) - hadoopConf + SparkHadoopUtil.newConfiguration(conf) } /** @@ -127,30 +120,55 @@ class SparkHadoopUtil extends Logging { def isYarnMode(): Boolean = { false } - def getCurrentUserCredentials(): Credentials = { null } - - def addCurrentUserCredentials(creds: Credentials) {} - def addSecretKeyToUserCredentials(key: String, secret: String) {} def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null } - def loginUserFromKeytab(principalName: String, keytabFilename: String) { - UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename) + def getCurrentUserCredentials(): Credentials = { + UserGroupInformation.getCurrentUser().getCredentials() + } + + def addCurrentUserCredentials(creds: Credentials): Unit = { + UserGroupInformation.getCurrentUser.addCredentials(creds) + } + + def loginUserFromKeytab(principalName: String, keytabFilename: String): Unit = { + if (!new File(keytabFilename).exists()) { + throw new SparkException(s"Keytab file: ${keytabFilename} does not exist") + } else { + logInfo("Attempting to login to Kerberos" + + s" using principal: ${principalName} and keytab: ${keytabFilename}") + UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename) + } } /** * Returns a function that can be called to find Hadoop FileSystem bytes read. If * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will * return the bytes read on r since t. - * - * @return None if the required method can't be found. */ private[spark] def getFSBytesReadOnThreadCallback(): () => Long = { - val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics) - val f = () => threadStats.map(_.getBytesRead).sum - val baselineBytesRead = f() - () => f() - baselineBytesRead + val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum + val baseline = (Thread.currentThread().getId, f()) + + /** + * This function may be called in both spawned child threads and parent task thread (in + * PythonRDD), and Hadoop FileSystem uses thread local variables to track the statistics. + * So we need a map to track the bytes read from the child threads and parent thread, + * summing them together to get the bytes read of this task. + */ + new Function0[Long] { + private val bytesReadMap = new mutable.HashMap[Long, Long]() + + override def apply(): Long = { + bytesReadMap.synchronized { + bytesReadMap.put(Thread.currentThread().getId, f()) + bytesReadMap.map { case (k, v) => + v - (if (k == baseline._1) baseline._2 else 0) + }.sum + } + } + } } /** @@ -211,6 +229,10 @@ class SparkHadoopUtil extends Logging { def globPath(pattern: Path): Seq[Path] = { val fs = pattern.getFileSystem(conf) + globPath(fs, pattern) + } + + def globPath(fs: FileSystem, pattern: Path): Seq[Path] = { Option(fs.globStatus(pattern)).map { statuses => statuses.map(_.getPath.makeQualified(fs.getUri, fs.getWorkingDirectory)).toSeq }.getOrElse(Seq.empty[Path]) @@ -220,6 +242,10 @@ class SparkHadoopUtil extends Logging { if (isGlobPath(pattern)) globPath(pattern) else Seq(pattern) } + def globPathIfNecessary(fs: FileSystem, pattern: Path): Seq[Path] = { + if (isGlobPath(pattern)) globPath(fs, pattern) else Seq(pattern) + } + /** * Lists all the files in a directory with the specified prefix, and does not end with the * given suffix. 
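
The rewritten getFSBytesReadOnThreadCallback above captures a baseline when the callback is created and then aggregates per-thread readings, because Hadoop's FileSystem statistics are thread-local and a task may spawn child threads. A simplified sketch of that baseline-plus-delta pattern, with the FileSystem statistics replaced by an ordinary counter:

```scala
object BytesReadCallbackSketch {
  // Snapshot a counter at creation time, then report growth since the snapshot,
  // summing the latest value recorded by each thread that invokes the callback.
  def deltaCallback(read: () => Long): () => Long = {
    val baseline = (Thread.currentThread().getId, read())
    val perThread = scala.collection.mutable.HashMap[Long, Long]()
    () => perThread.synchronized {
      perThread.put(Thread.currentThread().getId, read())
      perThread.map { case (tid, v) => v - (if (tid == baseline._1) baseline._2 else 0L) }.sum
    }
  }

  def main(args: Array[String]): Unit = {
    var counter = 100L                        // stand-in for FileSystem thread statistics
    val callback = deltaCallback(() => counter)
    counter += 42
    println(callback())                       // prints 42
  }
}
```
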
The returned {{FileStatus}} instances are sorted by the modification times of @@ -321,7 +347,7 @@ class SparkHadoopUtil extends Logging { if (credentials != null) { credentials.getAllTokens.asScala.map(tokenToString) } else { - Seq() + Seq.empty } } @@ -376,6 +402,21 @@ class SparkHadoopUtil extends Logging { s"${if (status.isDirectory) "d" else "-"}$perm") false } + + def serialize(creds: Credentials): Array[Byte] = { + val byteStream = new ByteArrayOutputStream + val dataStream = new DataOutputStream(byteStream) + creds.writeTokenStorageToStream(dataStream) + byteStream.toByteArray + } + + def deserialize(tokenBytes: Array[Byte]): Credentials = { + val tokensBuf = new ByteArrayInputStream(tokenBytes) + + val creds = new Credentials() + creds.readTokenStorageStream(new DataInputStream(tokensBuf)) + creds + } } object SparkHadoopUtil { @@ -411,4 +452,50 @@ object SparkHadoopUtil { hadoop } } + + /** + * Returns a Configuration object with Spark configuration applied on top. Unlike + * the instance method, this will always return a Configuration instance, and not a + * cluster manager-specific type. + */ + private[spark] def newConfiguration(conf: SparkConf): Configuration = { + val hadoopConf = new Configuration() + appendS3AndSparkHadoopConfigurations(conf, hadoopConf) + hadoopConf + } + + private def appendS3AndSparkHadoopConfigurations( + conf: SparkConf, + hadoopConf: Configuration): Unit = { + // Note: this null check is around more than just access to the "conf" object to maintain + // the behavior of the old implementation of this code, for backwards compatibility. + if (conf != null) { + // Explicitly check for S3 environment variables + val keyId = System.getenv("AWS_ACCESS_KEY_ID") + val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") + if (keyId != null && accessKey != null) { + hadoopConf.set("fs.s3.awsAccessKeyId", keyId) + hadoopConf.set("fs.s3n.awsAccessKeyId", keyId) + hadoopConf.set("fs.s3a.access.key", keyId) + hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey) + hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey) + hadoopConf.set("fs.s3a.secret.key", accessKey) + + val sessionToken = System.getenv("AWS_SESSION_TOKEN") + if (sessionToken != null) { + hadoopConf.set("fs.s3a.session.token", sessionToken) + } + } + appendSparkHadoopConfigs(conf, hadoopConf) + val bufferSize = conf.get("spark.buffer.size", "65536") + hadoopConf.set("io.file.buffer.size", bufferSize) + } + } + + private def appendSparkHadoopConfigs(conf: SparkConf, hadoopConf: Configuration): Unit = { + // Copy any "spark.hadoop.foo=bar" spark properties into conf as "foo=bar" + for ((key, value) <- conf.getAll if key.startsWith("spark.hadoop.")) { + hadoopConf.set(key.substring("spark.hadoop.".length), value) + } + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 77005aa9040b..286a4379d204 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import java.io.{File, IOException} +import java.io._ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL import java.security.PrivilegedExceptionAction @@ -25,11 +25,13 @@ import java.text.ParseException import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} -import scala.util.Properties +import scala.util.{Properties, Try} import 
org.apache.commons.lang3.StringUtils -import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.ivy.Ivy import org.apache.ivy.core.LogOptions import org.apache.ivy.core.module.descriptor._ @@ -45,6 +47,8 @@ import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBibl import org.apache.spark._ import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util._ @@ -63,7 +67,9 @@ private[deploy] object SparkSubmitAction extends Enumeration { * This program handles setting up the classpath with relevant Spark dependencies and provides * a layer over the different cluster managers and deploy modes that Spark supports. */ -object SparkSubmit extends CommandLineUtils { +object SparkSubmit extends CommandLineUtils with Logging { + + import DependencyUtils._ // Cluster managers private val YARN = 1 @@ -107,6 +113,10 @@ object SparkSubmit extends CommandLineUtils { // scalastyle:on println override def main(args: Array[String]): Unit = { + // Initialize logging if it hasn't been done yet. Keep track of whether logging needs to + // be reset before the application starts. + val uninitLog = initializeLogIfNecessary(true, silent = true) + val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { // scalastyle:off println @@ -114,7 +124,7 @@ object SparkSubmit extends CommandLineUtils { // scalastyle:on println } appArgs.action match { - case SparkSubmitAction.SUBMIT => submit(appArgs) + case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog) case SparkSubmitAction.KILL => kill(appArgs) case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) } @@ -147,7 +157,7 @@ object SparkSubmit extends CommandLineUtils { * main class. */ @tailrec - private def submit(args: SparkSubmitArguments): Unit = { + private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = { val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) def doRunMain(): Unit = { @@ -179,11 +189,16 @@ object SparkSubmit extends CommandLineUtils { } } - // In standalone cluster mode, there are two submission gateways: - // (1) The traditional RPC gateway using o.a.s.deploy.Client as a wrapper - // (2) The new REST-based gateway introduced in Spark 1.3 - // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over - // to use the legacy gateway if the master endpoint turns out to be not a REST server. + // Let the main class re-initialize the logging system once it starts. + if (uninitLog) { + Logging.uninitialize() + } + + // In standalone cluster mode, there are two submission gateways: + // (1) The traditional RPC gateway using o.a.s.deploy.Client as a wrapper + // (2) The new REST-based gateway introduced in Spark 1.3 + // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over + // to use the legacy gateway if the master endpoint turns out to be not a REST server. if (args.isStandaloneCluster && args.useRest) { try { // scalastyle:off println @@ -196,7 +211,7 @@ object SparkSubmit extends CommandLineUtils { printWarning(s"Master endpoint ${args.master} was not a REST server. 
" + "Falling back to legacy submission gateway instead.") args.useRest = false - submit(args) + submit(args, false) } // In all other modes, just run the main class as prepared } else { @@ -206,14 +221,20 @@ object SparkSubmit extends CommandLineUtils { /** * Prepare the environment for submitting an application. - * This returns a 4-tuple: - * (1) the arguments for the child process, - * (2) a list of classpath entries for the child, - * (3) a map of system properties, and - * (4) the main class for the child + * + * @param args the parsed SparkSubmitArguments used for environment preparation. + * @param conf the Hadoop Configuration, this argument will only be set in unit test. + * @return a 4-tuple: + * (1) the arguments for the child process, + * (2) a list of classpath entries for the child, + * (3) a map of system properties, and + * (4) the main class for the child + * * Exposed for testing. */ - private[deploy] def prepareSubmitEnvironment(args: SparkSubmitArguments) + private[deploy] def prepareSubmitEnvironment( + args: SparkSubmitArguments, + conf: Option[HadoopConfiguration] = None) : (Seq[String], Seq[String], Map[String, String], String) = { // Return values val childArgs = new ArrayBuffer[String]() @@ -267,6 +288,25 @@ object SparkSubmit extends CommandLineUtils { } } + // Fail fast, the following modes are not supported or applicable + (clusterManager, deployMode) match { + case (STANDALONE, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python " + + "applications on standalone clusters.") + case (STANDALONE, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + + "applications on standalone clusters.") + case (LOCAL, CLUSTER) => + printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") + case (_, CLUSTER) if isShell(args.primaryResource) => + printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + case (_, CLUSTER) if isSqlShell(args.mainClass) => + printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") + case (_, CLUSTER) if isThriftServer(args.mainClass) => + printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.") + case _ => + } + // Update args.deployMode if it is null. It will be passed down as a Spark property later. (args.deployMode, deployMode) match { case (null, CLIENT) => args.deployMode = "client" @@ -275,76 +315,103 @@ object SparkSubmit extends CommandLineUtils { } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER - - // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files - // too for packages that include Python code - val exclusions: Seq[String] = - if (!StringUtils.isBlank(args.packagesExclusions)) { - args.packagesExclusions.split(",") - } else { - Nil + val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER + + if (!isMesosCluster && !isStandAloneCluster) { + // Resolve maven dependencies if there are any and add classpath to jars. 
Add them to py-files + // too for packages that include Python code + val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies( + args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath) + + if (!StringUtils.isBlank(resolvedMavenCoordinates)) { + args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) + if (args.isPython) { + args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) + } } - // Create the IvySettings, either load from file or build defaults - val ivySettings = args.sparkProperties.get("spark.jars.ivySettings").map { ivySettingsFile => - SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(args.repositories), - Option(args.ivyRepoPath)) - }.getOrElse { - SparkSubmitUtils.buildIvySettings(Option(args.repositories), Option(args.ivyRepoPath)) - } - - val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, - ivySettings, exclusions = exclusions) - if (!StringUtils.isBlank(resolvedMavenCoordinates)) { - args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) - if (args.isPython) { - args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) + // install any R packages that may have been passed through --jars or --packages. + // Spark Packages may contain R source code inside the jar. + if (args.isR && !StringUtils.isBlank(args.jars)) { + RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose) } } - // install any R packages that may have been passed through --jars or --packages. - // Spark Packages may contain R source code inside the jar. - if (args.isR && !StringUtils.isBlank(args.jars)) { - RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose) + val sparkConf = new SparkConf(false) + args.sparkProperties.foreach { case (k, v) => sparkConf.set(k, v) } + val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) + val targetDir = Utils.createTempDir() + + // Resolve glob path for different resources. + args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull + args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull + args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull + args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull + + // In client mode, download remote files. + var localPrimaryResource: String = null + var localJars: String = null + var localPyFiles: String = null + if (deployMode == CLIENT) { + // This security manager will not need an auth secret, but set a dummy value in case + // spark.authenticate is enabled, otherwise an exception is thrown. + sparkConf.set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") + val secMgr = new SecurityManager(sparkConf) + localPrimaryResource = Option(args.primaryResource).map { + downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr) + }.orNull + localJars = Option(args.jars).map { + downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + }.orNull + localPyFiles = Option(args.pyFiles).map { + downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + }.orNull } - // Require all python files to be local, so we can add them to the PYTHONPATH - // In YARN cluster mode, python files are distributed as regular files, which can be non-local. - // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos. 
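
One idiom that recurs throughout the prepareSubmitEnvironment changes above is Option(x).map(transform).orNull, used because the argument strings (jars, files, pyFiles, archives) may legitimately be null rather than empty. A tiny illustration with a made-up transform:

```scala
object NullSafeListSketch {
  // Normalize a comma-separated list; a stand-in for resolveGlobPaths/downloadFileList.
  def normalize(list: String): String =
    list.split(",").map(_.trim).filter(_.nonEmpty).mkString(",")

  // Option(...).map(...).orNull keeps null inputs null instead of throwing.
  def transform(maybeList: String): String =
    Option(maybeList).map(normalize).orNull

  def main(args: Array[String]): Unit = {
    println(transform("a.jar, b.jar"))  // a.jar,b.jar
    println(transform(null))            // null
  }
}
```
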
- if (args.isPython && !isYarnCluster && !isMesosCluster) { - if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local python files are supported: ${args.primaryResource}") - } - val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",") - if (nonLocalPyFiles.nonEmpty) { - printErrorAndExit(s"Only local additional python files are supported: $nonLocalPyFiles") + // When running in YARN, for some remote resources with scheme: + // 1. Hadoop FileSystem doesn't support them. + // 2. We explicitly bypass Hadoop FileSystem with "spark.yarn.dist.forceDownloadSchemes". + // We will download them to local disk prior to add to YARN's distributed cache. + // For yarn client mode, since we already download them with above code, so we only need to + // figure out the local path and replace the remote one. + if (clusterManager == YARN) { + sparkConf.setIfMissing(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") + val secMgr = new SecurityManager(sparkConf) + val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) + + def shouldDownload(scheme: String): Boolean = { + forceDownloadSchemes.contains(scheme) || + Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure } - } - // Require all R files to be local - if (args.isR && !isYarnCluster && !isMesosCluster) { - if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") + def downloadResource(resource: String): String = { + val uri = Utils.resolveURI(resource) + uri.getScheme match { + case "local" | "file" => resource + case e if shouldDownload(e) => + val file = new File(targetDir, new Path(uri).getName) + if (file.exists()) { + file.toURI.toString + } else { + downloadFile(resource, targetDir, sparkConf, hadoopConf, secMgr) + } + case _ => uri.toString + } } - } - // The following modes are not supported or applicable - (clusterManager, deployMode) match { - case (STANDALONE, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python " + - "applications on standalone clusters.") - case (STANDALONE, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + - "applications on standalone clusters.") - case (LOCAL, CLUSTER) => - printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") - case (_, CLUSTER) if isShell(args.primaryResource) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") - case (_, CLUSTER) if isSqlShell(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") - case (_, CLUSTER) if isThriftServer(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.") - case _ => + args.primaryResource = Option(args.primaryResource).map { downloadResource }.orNull + args.files = Option(args.files).map { files => + Utils.stringToSeq(files).map(downloadResource).mkString(",") + }.orNull + args.pyFiles = Option(args.pyFiles).map { pyFiles => + Utils.stringToSeq(pyFiles).map(downloadResource).mkString(",") + }.orNull + args.jars = Option(args.jars).map { jars => + Utils.stringToSeq(jars).map(downloadResource).mkString(",") + }.orNull + args.archives = Option(args.archives).map { archives => + Utils.stringToSeq(archives).map(downloadResource).mkString(",") + }.orNull } // If we're running a python app, set the main class to our specific python runner @@ -355,7 +422,7 @@ object 
SparkSubmit extends CommandLineUtils { // If a python file is provided, add it to the child arguments and list of files to deploy. // Usage: PythonAppRunner <main python file> <extra python files>
[app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" - args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs + args.childArgs = ArrayBuffer(localPrimaryResource, localPyFiles) ++ args.childArgs if (clusterManager != YARN) { // The YARN backend distributes the primary file differently, so don't merge it. args.files = mergeFileLists(args.files, args.primaryResource) @@ -365,8 +432,8 @@ object SparkSubmit extends CommandLineUtils { // The YARN backend handles python files differently, so don't merge the lists. args.files = mergeFileLists(args.files, args.pyFiles) } - if (args.pyFiles != null) { - sysProps("spark.submit.pyFiles") = args.pyFiles + if (localPyFiles != null) { + sysProps("spark.submit.pyFiles") = localPyFiles } } @@ -420,7 +487,7 @@ object SparkSubmit extends CommandLineUtils { // If an R file is provided, add it to the child arguments and list of files to deploy. // Usage: RRunner
[app arguments] args.mainClass = "org.apache.spark.deploy.RRunner" - args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs + args.childArgs = ArrayBuffer(localPrimaryResource) ++ args.childArgs args.files = mergeFileLists(args.files, args.primaryResource) } } @@ -453,10 +520,19 @@ object SparkSubmit extends CommandLineUtils { OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.driver.extraLibraryPath"), + // Propagate attributes for dependency resolution at the driver side + OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.packages"), + OptionAssigner(args.repositories, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.jars.repositories"), + OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.ivy"), + OptionAssigner(args.packagesExclusions, STANDALONE | MESOS, + CLUSTER, sysProp = "spark.jars.excludes"), + // Yarn only OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.instances"), + OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.pyFiles"), OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.jars"), OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.archives"), @@ -480,15 +556,28 @@ object SparkSubmit extends CommandLineUtils { sysProp = "spark.driver.cores"), OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, sysProp = "spark.driver.supervise"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy") + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), + + // An internal option used only for spark-shell to add user jars to repl's classloader, + // previously it uses "spark.jars" or "spark.yarn.dist.jars" which now may be pointed to + // remote jars, so adding a new option to only specify local jars for spark-shell internally. + OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.repl.local.jars") ) // In client mode, launch the application main class directly // In addition, add the main application jar and any added jars (if any) to the classpath - // Also add the main application jar and any added jars to classpath in case YARN client - // requires these jars. - if (deployMode == CLIENT || isYarnCluster) { + if (deployMode == CLIENT) { childMainClass = args.mainClass + if (localPrimaryResource != null && isUserJar(localPrimaryResource)) { + childClasspath += localPrimaryResource + } + if (localJars != null) { childClasspath ++= localJars.split(",") } + } + // Add the main application jar and any added jars to classpath in case YARN client + // requires these jars. + // This assumes both primaryResource and user jars are local jars, otherwise it will not be + // added to the classpath of YARN client. 
+ if (isYarnCluster) { if (isUserJar(args.primaryResource)) { childClasspath += args.primaryResource } @@ -545,31 +634,28 @@ object SparkSubmit extends CommandLineUtils { if (args.isPython) { sysProps.put("spark.yarn.isPython", "true") } - - if (args.pyFiles != null) { - sysProps("spark.submit.pyFiles") = args.pyFiles - } } // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL) { + if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { if (args.principal != null) { - require(args.keytab != null, "Keytab must be specified when principal is specified") - if (!new File(args.keytab).exists()) { - throw new SparkException(s"Keytab file: ${args.keytab} does not exist") - } else { + if (args.keytab != null) { + require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") // Add keytab and principal configurations in sysProps to make them available // for later use; e.g. in spark sql, the isolated class loader used to talk // to HiveMetastore will use these settings. They will be set as Java system // properties and then loaded by SparkConf sysProps.put("spark.yarn.keytab", args.keytab) sysProps.put("spark.yarn.principal", args.principal) - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) } } } + if (clusterManager == MESOS && UserGroupInformation.isSecurityEnabled) { + setRMPrincipal(sysProps) + } + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" @@ -654,6 +740,18 @@ object SparkSubmit extends CommandLineUtils { (childArgs, childClasspath, sysProps, childMainClass) } + // [SPARK-20328]. HadoopRDD calls into a Hadoop library that fetches delegation tokens with + // renewer set to the YARN ResourceManager. Since YARN isn't configured in Mesos mode, we + // must trick it into thinking we're YARN. + private def setRMPrincipal(sysProps: HashMap[String, String]): Unit = { + val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName + val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" + // scalastyle:off println + printStream.println(s"Setting ${key} to ${shortUserName}") + // scalastyle:off println + sysProps.put(key, shortUserName) + } + /** * Run the main method of the child class using the provided launch environment. * @@ -754,7 +852,7 @@ object SparkSubmit extends CommandLineUtils { } } - private def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { + private[deploy] def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { val uri = Utils.resolveURI(localJar) uri.getScheme match { case "file" | "local" => @@ -819,12 +917,13 @@ object SparkSubmit extends CommandLineUtils { * Merge a sequence of comma-separated file lists, some of which may be null to indicate * no files, into a single comma-separated string. */ - private def mergeFileLists(lists: String*): String = { + private[deploy] def mergeFileLists(lists: String*): String = { val merged = lists.filterNot(StringUtils.isBlank) .flatMap(_.split(",")) .mkString(",") if (merged == "") null else merged } + } /** Provides utility functions to be used inside SparkSubmit. */ @@ -833,6 +932,15 @@ private[spark] object SparkSubmitUtils { // Exposed for testing var printStream = SparkSubmit.printStream + // Exposed for testing. + // These components are used to make the default exclusion rules for Spark dependencies. 
+ // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka-0-8 and + // other spark-streaming utility components. Underscore is there to differentiate between + // spark-streaming_2.1x and spark-streaming-kafka-0-8-assembly_2.1x + val IVY_DEFAULT_EXCLUDES = Seq("catalyst_", "core_", "graphx_", "kvstore_", "launcher_", "mllib_", + "mllib-local_", "network-common_", "network-shuffle_", "repl_", "sketch_", "sql_", "streaming_", + "tags_", "unsafe_") + /** * Represents a Maven Coordinate * @param groupId the groupId of the coordinate @@ -961,13 +1069,7 @@ private[spark] object SparkSubmitUtils { // Add scala exclusion rule md.addExcludeRule(createExclusion("*:scala-library:*", ivySettings, ivyConfName)) - // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka-0-8 and - // other spark-streaming utility components. Underscore is there to differentiate between - // spark-streaming_2.1x and spark-streaming-kafka-0-8-assembly_2.1x - val components = Seq("catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", - "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") - - components.foreach { comp => + IVY_DEFAULT_EXCLUDES.foreach { comp => md.addExcludeRule(createExclusion(s"org.apache.spark:spark-$comp*:*", ivySettings, ivyConfName)) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0144fd1056ba..a7722e4f8602 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -27,11 +27,14 @@ import java.util.jar.JarFile import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.io.Source +import scala.util.Try import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.launcher.SparkSubmitArgumentsParser +import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils + /** * Parses and encapsulates arguments from the spark-submit script. * The env argument is used for testing. 
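The resource-handling changes above reduce, for each URI, to a small per-scheme decision: keep `local:`/`file:` paths as they are, download anything whose scheme is listed in `spark.yarn.dist.forceDownloadSchemes` or has no Hadoop `FileSystem` implementation, and hand everything else to the distributed cache untouched. Below is a minimal standalone sketch of that decision, not the actual Spark code; `schemeHasFileSystem` and `downloadToLocalDir` are illustrative placeholders rather than Spark APIs.

```scala
import java.net.URI

object ResourceDownloadDecision {
  // True when the resource must be pulled to a local temp dir before use.
  def shouldDownload(
      scheme: String,
      forceDownloadSchemes: Set[String],
      schemeHasFileSystem: String => Boolean): Boolean = {
    forceDownloadSchemes.contains(scheme) || !schemeHasFileSystem(scheme)
  }

  // Decide, per resource URI, whether to keep it, download it, or pass it through.
  def resolveResource(
      resource: String,
      forceDownloadSchemes: Set[String],
      schemeHasFileSystem: String => Boolean,
      downloadToLocalDir: String => String): String = {
    val uri = new URI(resource)
    Option(uri.getScheme).getOrElse("file") match {
      case "local" | "file" => resource                         // already usable locally
      case s if shouldDownload(s, forceDownloadSchemes, schemeHasFileSystem) =>
        downloadToLocalDir(resource)                            // fetch before distributing
      case _ => uri.toString                                    // let the cluster distribute it
    }
  }

  def main(args: Array[String]): Unit = {
    val resolved = resolveResource(
      "http://repo.example.com/app.jar",
      forceDownloadSchemes = Set("http", "https"),
      schemeHasFileSystem = Set("hdfs", "file"),   // a Set[String] doubles as String => Boolean
      downloadToLocalDir = r => s"/tmp/spark-downloads/${r.split('/').last}")
    println(resolved)  // prints /tmp/spark-downloads/app.jar
  }
}
```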
@@ -184,6 +187,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull + repositories = Option(repositories) + .orElse(sparkProperties.get("spark.jars.repositories")).orNull deployMode = Option(deployMode) .orElse(sparkProperties.get("spark.submit.deployMode")) .orElse(env.get("DEPLOY_MODE")) @@ -202,11 +207,12 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S uriScheme match { case "file" => try { - val jar = new JarFile(uri.getPath) - // Note that this might still return null if no main-class is set; we catch that later - mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class") + Utils.tryWithResource(new JarFile(uri.getPath)) { jar => + // Note that this might still return null if no main-class is set; we catch that later + mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class") + } } catch { - case e: Exception => + case _: Exception => SparkSubmit.printErrorAndExit(s"Cannot load main class from JAR $primaryResource") } case _ => @@ -253,6 +259,23 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") } + if (driverMemory != null + && Try(JavaUtils.byteStringAsBytes(driverMemory)).getOrElse(-1L) <= 0) { + SparkSubmit.printErrorAndExit("Driver Memory must be a positive number") + } + if (executorMemory != null + && Try(JavaUtils.byteStringAsBytes(executorMemory)).getOrElse(-1L) <= 0) { + SparkSubmit.printErrorAndExit("Executor Memory cores must be a positive number") + } + if (executorCores != null && Try(executorCores.toInt).getOrElse(-1) <= 0) { + SparkSubmit.printErrorAndExit("Executor cores must be a positive number") + } + if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { + SparkSubmit.printErrorAndExit("Total executor cores must be a positive number") + } + if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { + SparkSubmit.printErrorAndExit("Number of executors must be a positive number") + } if (pyFiles != null && !isPython) { SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script") } @@ -482,7 +505,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println("Unknown/unsupported param " + unknownParam) } val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( - """Usage: spark-submit [options] [app arguments] + """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] |Usage: spark-submit --status [submission ID] --master [spark://...] |Usage: spark-submit run-example [options] example-class [example args]""".stripMargin) @@ -492,13 +515,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println( s""" |Options: - | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. + | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local + | (Default: local[*]). | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or | on one of the worker machines inside the cluster ("cluster") | (Default: client). 
| --class CLASS_NAME Your application's main class (for Java / Scala apps). | --name NAME A name of your application. - | --jars JARS Comma-separated list of local jars to include on the driver + | --jars JARS Comma-separated list of jars to include on the driver | and executor classpaths. | --packages Comma-separated list of maven coordinates of jars to include | on the driver and executor classpaths. Will search the local @@ -536,8 +560,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --verbose, -v Print additional debug output. | --version, Print the version of current Spark. | - | Spark standalone with cluster deploy mode only: - | --driver-cores NUM Cores for driver (Default: 1). + | Cluster deploy mode only: + | --driver-cores NUM Number of cores used by the driver, only in cluster mode + | (Default: 1). | | Spark standalone or Mesos with cluster deploy mode only: | --supervise If given, restarts the driver on failure. @@ -552,8 +577,6 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | or all available cores on the worker in standalone mode) | | YARN-only: - | --driver-cores NUM Number of cores used by the driver, only in cluster mode - | (Default: 1). | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). | If dynamic allocation is enabled, the initial number of diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 93f58ce63799..757c930b84eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -182,6 +182,10 @@ private[spark] class StandaloneAppClient( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus, workerLost) } + case WorkerRemoved(id, host, message) => + logInfo("Master removed worker %s: %s".format(id, message)) + listener.workerRemoved(id, host, message) + case MasterChanged(masterRef, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) master = Some(masterRef) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala index 64255ec92b72..d8bc1a883def 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala @@ -18,9 +18,9 @@ package org.apache.spark.deploy.client /** - * Callbacks invoked by deploy client when various events happen. There are currently four events: - * connecting to the cluster, disconnecting, being given an executor, and having an executor - * removed (either due to failure or due to revocation). + * Callbacks invoked by deploy client when various events happen. There are currently five events: + * connecting to the cluster, disconnecting, being given an executor, having an executor removed + * (either due to failure or due to revocation), and having a worker removed. * * Users of this API should *not* block inside the callback methods. 
*/ @@ -38,4 +38,6 @@ private[spark] trait StandaloneAppClientListener { def executorRemoved( fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit + + def workerRemoved(workerId: String, host: String, message: String): Unit } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 6d8758a3d3b1..5cb48ca3e60b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -30,7 +30,8 @@ private[spark] case class ApplicationAttemptInfo( endTime: Long, lastUpdated: Long, sparkUser: String, - completed: Boolean = false) + completed: Boolean = false, + appSparkVersion: String) private[spark] case class ApplicationHistoryInfo( id: String, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index f4235df24512..3889dd097ee5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -17,14 +17,17 @@ package org.apache.spark.deploy.history -import java.io.{FileNotFoundException, IOException, OutputStream} -import java.util.UUID +import java.io.{File, FileNotFoundException, IOException} +import java.util.{Date, UUID} import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.xml.Node +import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude} +import com.fasterxml.jackson.module.scala.DefaultScalaModule import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, Path} @@ -35,11 +38,14 @@ import org.apache.hadoop.security.AccessControlException import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ +import org.apache.spark.status.api.v1 import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} +import org.apache.spark.util.kvstore._ /** * A class that provides application history from event logs stored in the file system. @@ -50,11 +56,10 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any * entries in the log dir whose modification time is greater than the last scan time - * are considered new or updated. These are replayed to create a new [[FsApplicationAttemptInfo]] - * entry and update or create a matching [[FsApplicationHistoryInfo]] element in the list - * of applications. + * are considered new or updated. These are replayed to create a new attempt info entry + * and update or create a matching application info element in the list of applications. * - Updated attempts are also found in [[checkForLogs]] -- if the attempt's log file has grown, the - * [[FsApplicationAttemptInfo]] is replaced by another one with a larger log size. 
+ * attempt is replaced by another one with a larger log size. * - When [[updateProbe()]] is invoked to check if a loaded [[SparkUI]] * instance is out of date, the log size of the cached instance is checked against the app last * loaded by [[checkForLogs]]. @@ -78,6 +83,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) this(conf, new SystemClock()) } + import config._ import FsHistoryProvider._ // Interval between safemode checks. @@ -94,8 +100,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val NUM_PROCESSING_THREADS = conf.getInt(SPARK_HISTORY_FS_NUM_REPLAY_THREADS, Math.ceil(Runtime.getRuntime.availableProcessors() / 4f).toInt) - private val logDir = conf.getOption("spark.history.fs.logDirectory") - .getOrElse(DEFAULT_LOG_DIR) + private val logDir = conf.get(EVENT_LOG_DIR) private val HISTORY_UI_ACLS_ENABLE = conf.getBoolean("spark.history.ui.acls.enable", false) private val HISTORY_UI_ADMIN_ACLS = conf.get("spark.history.ui.admin.acls", "") @@ -117,17 +122,38 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // used for logging msgs (logs are re-scanned based on file size, rather than modtime) private val lastScanTime = new java.util.concurrent.atomic.AtomicLong(-1) - // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted - // into the map in order, so the LinkedHashMap maintains the correct ordering. - @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] - = new mutable.LinkedHashMap() + private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) - val fileToAppInfo = new mutable.HashMap[Path, FsApplicationAttemptInfo]() + private val storePath = conf.get(LOCAL_STORE_DIR) - // List of application logs to be deleted by event log cleaner. - private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] + // Visible for testing. + private[history] val listing: KVStore = storePath.map { path => + val dbPath = new File(path, "listing.ldb") - private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) + def openDB(): LevelDB = new LevelDB(dbPath, new KVStoreScalaSerializer()) + + try { + val db = openDB() + val meta = db.getMetadata(classOf[KVStoreMetadata]) + + if (meta == null) { + db.setMetadata(new KVStoreMetadata(CURRENT_LISTING_VERSION, logDir)) + db + } else if (meta.version != CURRENT_LISTING_VERSION || !logDir.equals(meta.logDir)) { + logInfo("Detected mismatched config in existing DB, deleting...") + db.close() + Utils.deleteRecursively(dbPath) + openDB() + } else { + db + } + } catch { + case _: UnsupportedStoreVersionException => + logInfo("Detected incompatible DB versions, deleting...") + Utils.deleteRecursively(dbPath) + openDB() + } + }.getOrElse(new InMemoryStore()) /** * Return a runnable that performs the given operation on the event logs. @@ -218,21 +244,36 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) if (!conf.contains("spark.testing")) { // A task that periodically checks for event log updates on disk. logDebug(s"Scheduling update thread every $UPDATE_INTERVAL_S seconds") - pool.scheduleWithFixedDelay(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) + pool.scheduleWithFixedDelay( + getRunner(() => checkForLogs()), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { // A task that periodically cleans event logs on disk. 
- pool.scheduleWithFixedDelay(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) + pool.scheduleWithFixedDelay( + getRunner(() => cleanLogs()), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) } } else { logDebug("Background update thread disabled for testing") } } - override def getListing(): Iterator[FsApplicationHistoryInfo] = applications.values.iterator + override def getListing(): Iterator[ApplicationHistoryInfo] = { + // Return the listing in end time descending order. + listing.view(classOf[ApplicationInfoWrapper]) + .index("endTime") + .reverse() + .iterator() + .asScala + .map(_.toAppHistoryInfo()) + } - override def getApplicationInfo(appId: String): Option[FsApplicationHistoryInfo] = { - applications.get(appId) + override def getApplicationInfo(appId: String): Option[ApplicationHistoryInfo] = { + try { + Some(load(appId).toAppHistoryInfo()) + } catch { + case e: NoSuchElementException => + None + } } override def getEventLogsUnderProcess(): Int = pendingReplayTasksCount.get() @@ -241,40 +282,40 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { - applications.get(appId).flatMap { appInfo => - appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt => + val appInfo = load(appId) + appInfo.attempts + .find(_.info.attemptId == attemptId) + .map { attempt => val replayBus = new ReplayListenerBus() val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, - HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.info.name, + HistoryServer.getAttemptURI(appId, attempt.info.attemptId), + Some(attempt.info.lastUpdated.getTime()), attempt.info.startTime.getTime()) // Do not call ui.bind() to avoid creating a new server for each application } val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) - - if (appListener.appId.isDefined) { - ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) - // make sure to set admin acls before view acls so they are properly picked up - val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") - ui.getSecurityManager.setAdminAcls(adminAcls) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, appListener.viewAcls.getOrElse("")) - val adminAclsGroups = HISTORY_UI_ADMIN_ACLS_GROUPS + "," + - appListener.adminAclsGroups.getOrElse("") - ui.getSecurityManager.setAdminAclsGroups(adminAclsGroups) - ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) - Some(LoadedAppUI(ui, updateProbe(appId, attemptId, attempt.fileSize))) - } else { - None - } - + assert(appListener.appId.isDefined) + ui.appSparkVersion = appListener.appSparkVersion.getOrElse("") + ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) + // make sure to set admin acls before view acls so they are properly picked up + val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") + ui.getSecurityManager.setAdminAcls(adminAcls) + ui.getSecurityManager.setViewAcls(attempt.info.sparkUser, + appListener.viewAcls.getOrElse("")) + val adminAclsGroups = HISTORY_UI_ADMIN_ACLS_GROUPS + "," + + appListener.adminAclsGroups.getOrElse("") + ui.getSecurityManager.setAdminAclsGroups(adminAclsGroups) + 
ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) + LoadedAppUI(ui, () => updateProbe(appId, attemptId, attempt.fileSize)) } - } } catch { - case e: FileNotFoundException => None + case _: FileNotFoundException => None + case _: NoSuchElementException => None } } @@ -299,9 +340,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } override def stop(): Unit = { - if (initThread != null && initThread.isAlive()) { - initThread.interrupt() - initThread.join() + try { + if (initThread != null && initThread.isAlive()) { + initThread.interrupt() + initThread.join() + } + } finally { + listing.close() } } @@ -314,24 +359,20 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { val newLastScanTime = getNewLastScanTime() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") - val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) - .getOrElse(Seq[FileStatus]()) // scan for modified applications, replay and merge them - val logInfos: Seq[FileStatus] = statusList + val logInfos = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => - val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) !entry.isDirectory() && // FsHistoryProvider generates a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && - prevFileSize < entry.getLen() && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) && + recordedFileSize(entry.getPath()) < entry.getLen() } - .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => - entry1.getModificationTime() >= entry2.getModificationTime() - } + entry1.getModificationTime() > entry2.getModificationTime() + } if (logInfos.nonEmpty) { logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") @@ -419,205 +460,104 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - applications.get(appId) match { - case Some(appInfo) => - try { - // If no attempt is specified, or there is no attemptId for attempts, return all attempts - appInfo.attempts.filter { attempt => - attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get - }.foreach { attempt => - val logPath = new Path(logDir, attempt.logPath) - zipFileToStream(logPath, attempt.logPath, zipStream) - } - } finally { - zipStream.close() + val app = try { + load(appId) + } catch { + case _: NoSuchElementException => + throw new SparkException(s"Logs for $appId not found.") + } + + try { + // If no attempt is specified, or there is no attemptId for attempts, return all attempts + attemptId + .map { id => app.attempts.filter(_.info.attemptId == Some(id)) } + .getOrElse(app.attempts) + .map(_.logPath) + .foreach { log => + zipFileToStream(new Path(logDir, log), log, zipStream) } - case None => throw new SparkException(s"Logs for $appId not found.") + } finally { + zipStream.close() } } /** - * Replay the log files in the list and merge the list of old applications with new ones + * Replay the given log file, saving the application in the listing db. 
*/ protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { - val newAttempts = try { - val eventsFilter: ReplayEventsFilter = { eventString => - eventString.startsWith(APPL_START_EVENT_PREFIX) || - eventString.startsWith(APPL_END_EVENT_PREFIX) - } - - val logPath = fileStatus.getPath() - val appCompleted = isApplicationCompleted(fileStatus) - - // Use loading time as lastUpdated since some filesystems don't update modifiedTime - // each time file is updated. However use modifiedTime for completed jobs so lastUpdated - // won't change whenever HistoryServer restarts and reloads the file. - val lastUpdated = if (appCompleted) fileStatus.getModificationTime else clock.getTimeMillis() - - val appListener = replay(fileStatus, appCompleted, new ReplayListenerBus(), eventsFilter) - - // Without an app ID, new logs will render incorrectly in the listing page, so do not list or - // try to show their UI. - if (appListener.appId.isDefined) { - val attemptInfo = new FsApplicationAttemptInfo( - logPath.getName(), - appListener.appName.getOrElse(NOT_STARTED), - appListener.appId.getOrElse(logPath.getName()), - appListener.appAttemptId, - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - lastUpdated, - appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted, - fileStatus.getLen() - ) - fileToAppInfo(logPath) = attemptInfo - logDebug(s"Application log ${attemptInfo.logPath} loaded successfully: $attemptInfo") - Some(attemptInfo) - } else { - logWarning(s"Failed to load application log ${fileStatus.getPath}. " + - "The application may have not started.") - None - } - - } catch { - case e: Exception => - logError( - s"Exception encountered when attempting to load application log ${fileStatus.getPath}", - e) - None - } - - if (newAttempts.isEmpty) { - return + val eventsFilter: ReplayEventsFilter = { eventString => + eventString.startsWith(APPL_START_EVENT_PREFIX) || + eventString.startsWith(APPL_END_EVENT_PREFIX) || + eventString.startsWith(LOG_START_EVENT_PREFIX) } - // Build a map containing all apps that contain new attempts. The app information in this map - // contains both the new app attempt, and those that were already loaded in the existing apps - // map. If an attempt has been updated, it replaces the old attempt in the list. - val newAppMap = new mutable.HashMap[String, FsApplicationHistoryInfo]() - - applications.synchronized { - newAttempts.foreach { attempt => - val appInfo = newAppMap.get(attempt.appId) - .orElse(applications.get(attempt.appId)) - .map { app => - val attempts = - app.attempts.filter(_.attemptId != attempt.attemptId) ++ List(attempt) - new FsApplicationHistoryInfo(attempt.appId, attempt.name, - attempts.sortWith(compareAttemptInfo)) - } - .getOrElse(new FsApplicationHistoryInfo(attempt.appId, attempt.name, List(attempt))) - newAppMap(attempt.appId) = appInfo - } + val logPath = fileStatus.getPath() + logInfo(s"Replaying log path: $logPath") - // Merge the new app list with the existing one, maintaining the expected ordering (descending - // end time). Maintaining the order is important to avoid having to sort the list every time - // there is a request for the log list. 
- val newApps = newAppMap.values.toSeq.sortWith(compareAppInfo) - val mergedApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() - def addIfAbsent(info: FsApplicationHistoryInfo): Unit = { - if (!mergedApps.contains(info.id)) { - mergedApps += (info.id -> info) - } - } + val bus = new ReplayListenerBus() + val listener = new AppListingListener(fileStatus, clock) + bus.addListener(listener) - val newIterator = newApps.iterator.buffered - val oldIterator = applications.values.iterator.buffered - while (newIterator.hasNext && oldIterator.hasNext) { - if (newAppMap.contains(oldIterator.head.id)) { - oldIterator.next() - } else if (compareAppInfo(newIterator.head, oldIterator.head)) { - addIfAbsent(newIterator.next()) - } else { - addIfAbsent(oldIterator.next()) - } - } - newIterator.foreach(addIfAbsent) - oldIterator.foreach(addIfAbsent) - - applications = mergedApps - } + replay(fileStatus, isApplicationCompleted(fileStatus), bus, eventsFilter) + listener.applicationInfo.foreach(addListing) + listing.write(LogInfo(logPath.toString(), fileStatus.getLen())) } /** * Delete event logs from the log directory according to the clean policy defined by the user. */ private[history] def cleanLogs(): Unit = { + var iterator: Option[KVStoreIterator[ApplicationInfoWrapper]] = None try { - val maxAge = conf.getTimeAsSeconds("spark.history.fs.cleaner.maxAge", "7d") * 1000 - - val now = clock.getTimeMillis() - val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() - - def shouldClean(attempt: FsApplicationAttemptInfo): Boolean = { - now - attempt.lastUpdated > maxAge - } + val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 + + // Iterate descending over all applications whose oldest attempt happened before maxTime. + iterator = Some(listing.view(classOf[ApplicationInfoWrapper]) + .index("oldestAttempt") + .reverse() + .first(maxTime) + .closeableIterator()) + + iterator.get.asScala.foreach { app => + // Applications may have multiple attempts, some of which may not need to be deleted yet. + val (remaining, toDelete) = app.attempts.partition { attempt => + attempt.info.lastUpdated.getTime() >= maxTime + } - // Scan all logs from the log directory. - // Only completed applications older than the specified max age will be deleted. 
- applications.values.foreach { app => - val (toClean, toRetain) = app.attempts.partition(shouldClean) - attemptsToClean ++= toClean - - if (toClean.isEmpty) { - appsToRetain += (app.id -> app) - } else if (toRetain.nonEmpty) { - appsToRetain += (app.id -> - new FsApplicationHistoryInfo(app.id, app.name, toRetain.toList)) + if (remaining.nonEmpty) { + val newApp = new ApplicationInfoWrapper(app.info, remaining) + listing.write(newApp) } - } - applications = appsToRetain + toDelete.foreach { attempt => + val logPath = new Path(logDir, attempt.logPath) + try { + listing.delete(classOf[LogInfo], logPath.toString()) + } catch { + case _: NoSuchElementException => + logDebug(s"Log info entry for $logPath not found.") + } + try { + fs.delete(logPath, true) + } catch { + case e: AccessControlException => + logInfo(s"No permission to delete ${attempt.logPath}, ignoring.") + case t: IOException => + logError(s"IOException in cleaning ${attempt.logPath}", t) + } + } - val leftToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] - attemptsToClean.foreach { attempt => - try { - fs.delete(new Path(logDir, attempt.logPath), true) - } catch { - case e: AccessControlException => - logInfo(s"No permission to delete ${attempt.logPath}, ignoring.") - case t: IOException => - logError(s"IOException in cleaning ${attempt.logPath}", t) - leftToClean += attempt + if (remaining.isEmpty) { + listing.delete(app.getClass(), app.id) } } - - attemptsToClean = leftToClean } catch { - case t: Exception => logError("Exception in cleaning logs", t) + case t: Exception => logError("Exception while cleaning logs", t) + } finally { + iterator.foreach(_.close()) } } - /** - * Comparison function that defines the sort order for the application listing. - * - * @return Whether `i1` should precede `i2`. - */ - private def compareAppInfo( - i1: FsApplicationHistoryInfo, - i2: FsApplicationHistoryInfo): Boolean = { - val a1 = i1.attempts.head - val a2 = i2.attempts.head - if (a1.endTime != a2.endTime) a1.endTime >= a2.endTime else a1.startTime >= a2.startTime - } - - /** - * Comparison function that defines the sort order for application attempts within the same - * application. Order is: attempts are sorted by descending start time. - * Most recent attempt state matches with current state of the app. - * - * Normally applications should have a single running attempt; but failure to call sc.stop() - * may cause multiple running attempts to show up. - * - * @return Whether `a1` should precede `a2`. - */ - private def compareAttemptInfo( - a1: FsApplicationAttemptInfo, - a2: FsApplicationAttemptInfo): Boolean = { - a1.startTime >= a2.startTime - } - /** * Replays the events in the specified log file on the supplied `ReplayListenerBus`. Returns * an `ApplicationEventListener` instance with event data captured from the replay. 
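With the listing moved into a KVStore, the comparator-based bookkeeping removed above is no longer needed, and the new cleanLogs() boils down to a simple retention rule: drop attempts whose lastUpdated is older than now minus spark.history.fs.cleaner.maxAge, and drop the application entry only once no attempts remain. The sketch below states that rule on its own, using simplified stand-in case classes rather than the real ApplicationInfoWrapper/AttemptInfoWrapper types.

```scala
// Simplified stand-ins for the history provider's wrapper types.
case class Attempt(id: Option[String], logPath: String, lastUpdated: Long)
case class App(id: String, attempts: List[Attempt])

object CleanLogsSketch {
  // Returns (applications to keep, attempts whose log files should be deleted).
  def clean(apps: Seq[App], nowMs: Long, maxAgeMs: Long): (Seq[App], Seq[Attempt]) = {
    val maxTime = nowMs - maxAgeMs
    val results = apps.map { app =>
      // Attempts updated before maxTime are expired; the rest stay in the listing.
      val (remaining, toDelete) = app.attempts.partition(_.lastUpdated >= maxTime)
      (app.copy(attempts = remaining), toDelete)
    }
    // Keep only applications that still have at least one live attempt.
    val kept = results.collect { case (app, _) if app.attempts.nonEmpty => app }
    val deleted = results.flatMap(_._2)
    (kept, deleted)
  }

  def main(args: Array[String]): Unit = {
    val now = System.currentTimeMillis()
    val week = 7L * 24 * 60 * 60 * 1000
    val apps = Seq(
      App("app-1", List(Attempt(Some("1"), "app-1_1", now - 2 * week),
                        Attempt(Some("2"), "app-1_2", now))),
      App("app-2", List(Attempt(None, "app-2", now - 3 * week))))
    val (kept, deleted) = clean(apps, now, week)
    println(kept.map(_.id))          // List(app-1): app-2 lost its only attempt
    println(deleted.map(_.logPath))  // List(app-1_1, app-2)
  }
}
```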
@@ -642,6 +582,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appListener = new ApplicationEventListener bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted, eventsFilter) + logInfo(s"Finished replaying $logPath") appListener } finally { logInput.close() @@ -678,26 +619,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * @return a summary of the component state */ override def toString: String = { - val header = s""" - | FsHistoryProvider: logdir=$logDir, - | last scan time=$lastScanTime - | Cached application count =${applications.size}} - """.stripMargin - val sb = new StringBuilder(header) - applications.foreach(entry => sb.append(entry._2).append("\n")) - sb.toString - } - - /** - * Look up an application attempt - * @param appId application ID - * @param attemptId Attempt ID, if set - * @return the matching attempt, if found - */ - def lookup(appId: String, attemptId: Option[String]): Option[FsApplicationAttemptInfo] = { - applications.get(appId).flatMap { appInfo => - appInfo.attempts.find(_.attemptId == attemptId) - } + val count = listing.count(classOf[ApplicationInfoWrapper]) + s"""|FsHistoryProvider{logdir=$logDir, + | storedir=$storePath, + | last scan time=$lastScanTime + | application count=$count}""".stripMargin } /** @@ -715,72 +641,215 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) appId: String, attemptId: Option[String], prevFileSize: Long)(): Boolean = { - lookup(appId, attemptId) match { - case None => + try { + val attempt = getAttempt(appId, attemptId) + val logPath = fs.makeQualified(new Path(logDir, attempt.logPath)) + recordedFileSize(logPath) > prevFileSize + } catch { + case _: NoSuchElementException => logDebug(s"Application Attempt $appId/$attemptId not found") false - case Some(latest) => - prevFileSize < latest.fileSize } } -} -private[history] object FsHistoryProvider { - val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + /** + * Return the last known size of the given event log, recorded the last time the file + * system scanner detected a change in the file. + */ + private def recordedFileSize(log: Path): Long = { + try { + listing.read(classOf[LogInfo], log.toString()).fileSize + } catch { + case _: NoSuchElementException => 0L + } + } + + private def load(appId: String): ApplicationInfoWrapper = { + listing.read(classOf[ApplicationInfoWrapper], appId) + } + + /** + * Write the app's information to the given store. Serialized to avoid the (notedly rare) case + * where two threads are processing separate attempts of the same application. + */ + private def addListing(app: ApplicationInfoWrapper): Unit = listing.synchronized { + val attempt = app.attempts.head + + val oldApp = try { + load(app.id) + } catch { + case _: NoSuchElementException => + app + } + + def compareAttemptInfo(a1: AttemptInfoWrapper, a2: AttemptInfoWrapper): Boolean = { + a1.info.startTime.getTime() > a2.info.startTime.getTime() + } + + val attempts = oldApp.attempts.filter(_.info.attemptId != attempt.info.attemptId) ++ + List(attempt) + + val newAppInfo = new ApplicationInfoWrapper( + app.info, + attempts.sortWith(compareAttemptInfo)) + listing.write(newAppInfo) + } - private val NOT_STARTED = "" + /** For testing. Returns internal data about a single attempt. 
*/ + private[history] def getAttempt(appId: String, attemptId: Option[String]): AttemptInfoWrapper = { + load(appId).attempts.find(_.info.attemptId == attemptId).getOrElse( + throw new NoSuchElementException(s"Cannot find attempt $attemptId of $appId.")) + } +} + +private[history] object FsHistoryProvider { private val SPARK_HISTORY_FS_NUM_REPLAY_THREADS = "spark.history.fs.numReplayThreads" private val APPL_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationStart\"" private val APPL_END_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationEnd\"" + + private val LOG_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerLogStart\"" + + /** + * Current version of the data written to the listing database. When opening an existing + * db, if the version does not match this value, the FsHistoryProvider will throw away + * all data and re-generate the listing data from the event logs. + */ + private[history] val CURRENT_LISTING_VERSION = 1L } /** - * Application attempt information. - * - * @param logPath path to the log file, or, for a legacy log, its directory - * @param name application name - * @param appId application ID - * @param attemptId optional attempt ID - * @param startTime start time (from playback) - * @param endTime end time (from playback). -1 if the application is incomplete. - * @param lastUpdated the modification time of the log file when this entry was built by replaying - * the history. - * @param sparkUser user running the application - * @param completed flag to indicate whether or not the application has completed. - * @param fileSize the size of the log file the last time the file was scanned for changes + * A KVStoreSerializer that provides Scala types serialization too, and uses the same options as + * the API serializer. */ -private class FsApplicationAttemptInfo( +private class KVStoreScalaSerializer extends KVStoreSerializer { + + mapper.registerModule(DefaultScalaModule) + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) + mapper.setDateFormat(v1.JacksonMessageWriter.makeISODateFormat) + +} + +private[history] case class KVStoreMetadata( + version: Long, + logDir: String) + +private[history] case class LogInfo( + @KVIndexParam logPath: String, + fileSize: Long) + +private[history] class AttemptInfoWrapper( + val info: v1.ApplicationAttemptInfo, val logPath: String, - val name: String, - val appId: String, - attemptId: Option[String], - startTime: Long, - endTime: Long, - lastUpdated: Long, - sparkUser: String, - completed: Boolean, - val fileSize: Long) - extends ApplicationAttemptInfo( - attemptId, startTime, endTime, lastUpdated, sparkUser, completed) { - - /** extend the superclass string value with the extra attributes of this class */ - override def toString: String = { - s"FsApplicationAttemptInfo($name, $appId," + - s" ${super.toString}, source=$logPath, size=$fileSize" + val fileSize: Long) { + + def toAppAttemptInfo(): ApplicationAttemptInfo = { + ApplicationAttemptInfo(info.attemptId, info.startTime.getTime(), + info.endTime.getTime(), info.lastUpdated.getTime(), info.sparkUser, + info.completed, info.appSparkVersion) } + } -/** - * Application history information - * @param id application ID - * @param name application name - * @param attempts list of attempts, most recent first. 
- */ -private class FsApplicationHistoryInfo( - id: String, - override val name: String, - override val attempts: List[FsApplicationAttemptInfo]) - extends ApplicationHistoryInfo(id, name, attempts) +private[history] class ApplicationInfoWrapper( + val info: v1.ApplicationInfo, + val attempts: List[AttemptInfoWrapper]) { + + @JsonIgnore @KVIndexParam + def id: String = info.id + + @JsonIgnore @KVIndexParam("endTime") + def endTime(): Long = attempts.head.info.endTime.getTime() + + @JsonIgnore @KVIndexParam("oldestAttempt") + def oldestAttempt(): Long = attempts.map(_.info.lastUpdated.getTime()).min + + def toAppHistoryInfo(): ApplicationHistoryInfo = { + ApplicationHistoryInfo(info.id, info.name, attempts.map(_.toAppAttemptInfo())) + } + +} + +private[history] class AppListingListener(log: FileStatus, clock: Clock) extends SparkListener { + + private val app = new MutableApplicationInfo() + private val attempt = new MutableAttemptInfo(log.getPath().getName(), log.getLen()) + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { + app.id = event.appId.orNull + app.name = event.appName + + attempt.attemptId = event.appAttemptId + attempt.startTime = new Date(event.time) + attempt.lastUpdated = new Date(clock.getTimeMillis()) + attempt.sparkUser = event.sparkUser + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { + attempt.endTime = new Date(event.time) + attempt.lastUpdated = new Date(log.getModificationTime()) + attempt.duration = event.time - attempt.startTime.getTime() + attempt.completed = true + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(sparkVersion) => + attempt.appSparkVersion = sparkVersion + case _ => + } + + def applicationInfo: Option[ApplicationInfoWrapper] = { + if (app.id != null) { + Some(app.toView()) + } else { + None + } + } + + private class MutableApplicationInfo { + var id: String = null + var name: String = null + var coresGranted: Option[Int] = None + var maxCores: Option[Int] = None + var coresPerExecutor: Option[Int] = None + var memoryPerExecutorMB: Option[Int] = None + + def toView(): ApplicationInfoWrapper = { + val apiInfo = new v1.ApplicationInfo(id, name, coresGranted, maxCores, coresPerExecutor, + memoryPerExecutorMB, Nil) + new ApplicationInfoWrapper(apiInfo, List(attempt.toView())) + } + + } + + private class MutableAttemptInfo(logPath: String, fileSize: Long) { + var attemptId: Option[String] = None + var startTime = new Date(-1) + var endTime = new Date(-1) + var lastUpdated = new Date(-1) + var duration = 0L + var sparkUser: String = null + var completed = false + var appSparkVersion = "" + + def toView(): AttemptInfoWrapper = { + val apiInfo = new v1.ApplicationAttemptInfo( + attemptId, + startTime, + endTime, + lastUpdated, + duration, + sparkUser, + completed, + appSparkVersion) + new AttemptInfoWrapper( + apiInfo, + logPath, + fileSize) + } + + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0e7a6c24d4fa..af1471763340 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -26,8 +26,9 @@ import org.apache.spark.ui.{UIUtils, WebUIPage} private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { + // stripXSS is called first to 
remove suspicious characters used in XSS attacks val requestedIncomplete = - Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean + Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala new file mode 100644 index 000000000000..fb9e997def0d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import java.util.concurrent.TimeUnit + +import scala.annotation.meta.getter + +import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.util.kvstore.KVIndex + +private[spark] object config { + + /** Use this to annotate constructor params to be used as KVStore indices. */ + type KVIndexParam = KVIndex @getter + + val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + val EVENT_LOG_DIR = ConfigBuilder("spark.history.fs.logDirectory") + .stringConf + .createWithDefault(DEFAULT_LOG_DIR) + + val MAX_LOG_AGE_S = ConfigBuilder("spark.history.fs.cleaner.maxAge") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("7d") + + val LOCAL_STORE_DIR = ConfigBuilder("spark.history.store.path") + .doc("Local directory where to cache application history information. 
By default this is " + + "not set, meaning all history information will be kept in memory.") + .stringConf + .createOptional + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 816bf37e39fe..e030cac60a8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ import org.apache.spark.serializer.{JavaSerializer, Serializer} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{SparkUncaughtExceptionHandler, ThreadUtils, Utils} private[deploy] class Master( override val rpcEnv: RpcEnv, @@ -80,7 +80,7 @@ private[deploy] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(address.host, "Expected hostname") + Utils.checkHost(address.host) private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, @@ -133,6 +133,7 @@ private[deploy] class Master( masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort if (reverseProxy) { masterWebUiUrl = conf.get("spark.ui.reverseProxyUrl", masterWebUiUrl) + webUi.addProxy() logInfo(s"Spark Master is acting as a reverse proxy. Master, Workers and " + s"Applications UIs are available at $masterWebUiUrl") } @@ -231,7 +232,8 @@ private[deploy] class Master( logError("Leadership has been revoked -- master shutting down.") System.exit(0) - case RegisterWorker(id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress) => logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -243,7 +245,7 @@ private[deploy] class Master( workerRef, workerWebUiUrl) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - workerRef.send(RegisteredWorker(self, masterWebUiUrl)) + workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress)) schedule() } else { val workerAddress = worker.endpoint.address @@ -366,7 +368,7 @@ private[deploy] class Master( drivers.find(_.id == driverId).foreach { driver => driver.worker = Some(worker) driver.state = DriverState.RUNNING - worker.drivers(driverId) = driver + worker.addDriver(driver) } } case None => @@ -497,7 +499,7 @@ private[deploy] class Master( override def onDisconnected(address: RpcAddress): Unit = { // The disconnected client could've been either a worker or an app; remove whichever it was logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) + addressToWorker.get(address).foreach(removeWorker(_, s"${address} got disassociated")) addressToApp.get(address).foreach(finishApplication) if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } } @@ -543,9 +545,13 @@ private[deploy] class Master( state = RecoveryState.COMPLETING_RECOVERY // Kill off any workers and apps that didn't respond to us. 
- workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) + workers.filter(_.state == WorkerState.UNKNOWN).foreach( + removeWorker(_, "Not responding for recovery")) apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication) + // Update the state of recovered apps to RUNNING + apps.filter(_.state == ApplicationState.WAITING).foreach(_.state = ApplicationState.RUNNING) + // Reschedule drivers which were not claimed by any workers drivers.filter(_.worker.isEmpty).foreach { d => logWarning(s"Driver ${d.id} was not found after master recovery") @@ -654,19 +660,22 @@ private[deploy] class Master( private def startExecutorsOnWorkers(): Unit = { // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. - for (app <- waitingApps if app.coresLeft > 0) { - val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor - // Filter out workers that don't have enough resources to launch an executor - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && - worker.coresFree >= coresPerExecutor.getOrElse(1)) - .sortBy(_.coresFree).reverse - val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) - - // Now that we've decided how many cores to allocate on each worker, let's allocate them - for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { - allocateWorkerResourceToExecutors( - app, assignedCores(pos), coresPerExecutor, usableWorkers(pos)) + for (app <- waitingApps) { + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(1) + // If the cores left is less than the coresPerExecutor,the cores left will not be allocated + if (app.coresLeft >= coresPerExecutor) { + // Filter out workers that don't have enough resources to launch an executor + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= coresPerExecutor) + .sortBy(_.coresFree).reverse + val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) + + // Now that we've decided how many cores to allocate on each worker, let's allocate them + for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { + allocateWorkerResourceToExecutors( + app, assignedCores(pos), app.desc.coresPerExecutor, usableWorkers(pos)) + } } } } @@ -751,7 +760,7 @@ private[deploy] class Master( if (oldWorker.state == WorkerState.UNKNOWN) { // A worker registering from UNKNOWN implies that the worker was restarted during recovery. // The old worker must thus be dead, so we will remove it and accept the new worker. 
- removeWorker(oldWorker) + removeWorker(oldWorker, "Worker replaced by a new worker with same address") } else { logInfo("Attempted to re-register worker at same address: " + workerAddress) return false @@ -761,20 +770,15 @@ private[deploy] class Master( workers += worker idToWorker(worker.id) = worker addressToWorker(workerAddress) = worker - if (reverseProxy) { - webUi.addProxyTargets(worker.id, worker.webUiAddress) - } true } - private def removeWorker(worker: WorkerInfo) { + private def removeWorker(worker: WorkerInfo, msg: String) { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id addressToWorker -= worker.endpoint.address - if (reverseProxy) { - webUi.removeProxyTargets(worker.id) - } + for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) exec.application.driver.send(ExecutorUpdated( @@ -791,13 +795,27 @@ private[deploy] class Master( removeDriver(driver.id, DriverState.ERROR, None) } } + logInfo(s"Telling app of lost worker: " + worker.id) + apps.filterNot(completedApps.contains(_)).foreach { app => + app.driver.send(WorkerRemoved(worker.id, worker.host, msg)) + } persistenceEngine.removeWorker(worker) } private def relaunchDriver(driver: DriverInfo) { - driver.worker = None - driver.state = DriverState.RELAUNCHING - waitingDrivers += driver + // We must setup a new driver with a new driver id here, because the original driver may + // be still running. Consider this scenario: a worker is network partitioned with master, + // the master then relaunches driver driverID1 with a driver id driverID2, then the worker + // reconnects to master. From this point on, if driverID2 is equal to driverID1, then master + // can not distinguish the statusUpdate of the original driver and the newly relaunched one, + // for example, when DriverStateChanged(driverID1, KILLED) arrives at master, master will + // remove driverID1, so the newly relaunched driver disappears too. See SPARK-19900 for details. + removeDriver(driver.id, DriverState.RELAUNCHING, None) + val newDriver = createDriver(driver.desc) + persistenceEngine.addDriver(newDriver) + drivers.add(newDriver) + waitingDrivers += newDriver + schedule() } @@ -822,9 +840,6 @@ private[deploy] class Master( endpointToApp(app.driver) = app addressToApp(appAddress) = app waitingApps += app - if (reverseProxy) { - webUi.addProxyTargets(app.id, app.desc.appUiUrl) - } } private def finishApplication(app: ApplicationInfo) { @@ -838,9 +853,7 @@ private[deploy] class Master( idToApp -= app.id endpointToApp -= app.driver addressToApp -= app.driver.address - if (reverseProxy) { - webUi.removeProxyTargets(app.id) - } + if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach { a => @@ -965,7 +978,7 @@ private[deploy] class Master( if (worker.state != WorkerState.DEAD) { logWarning("Removing %s because we got no heartbeat in %d seconds".format( worker.id, WORKER_TIMEOUT_MS / 1000)) - removeWorker(worker) + removeWorker(worker, s"Not receiving heartbeat for ${WORKER_TIMEOUT_MS / 1000} seconds") } else { if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) { workers -= worker // we've seen this DEAD worker in the UI, etc. 
for long enough; cull it @@ -1023,6 +1036,8 @@ private[deploy] object Master extends Logging { val ENDPOINT_NAME = "Master" def main(argStrings: Array[String]) { + Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler( + exitOnUncaughtException = false)) Utils.initDaemon(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index c63793c16dce..615d2533cf08 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -60,12 +60,12 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) exte @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 4e20c10fd142..c87d6e24b78c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -32,7 +32,7 @@ private[spark] class WorkerInfo( val webUiAddress: String) extends Serializable { - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index a8d721f3e0d4..68e57b7564ad 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { - val appId = request.getParameter("appId") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val appId = UIUtils.stripXSS(request.getParameter("appId")) val state = master.askSync[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId) .getOrElse(state.completedApps.find(_.id == appId).orNull) @@ -99,11 +100,11 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
-            <h4> Executor Summary </h4>
+            <h4> Executor Summary ({allExecutors.length}) </h4>
             {executorsTable}
             {
               if (removedExecutors.nonEmpty) {
-                <h4> Removed Executors </h4> ++
+                <h4> Removed Executors ({removedExecutors.length}) </h4>
++ removedExecutorsTable } } @@ -124,10 +125,10 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") {executor.memory} {executor.state} - stdout - stderr + stdout + stderr } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 9351c72094e3..bc0bf6a1d970 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { if (parent.killEnabled && parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) { - val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean - val id = Option(request.getParameter("id")) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val killFlag = + Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean + val id = Option(UIUtils.stripXSS(request.getParameter("id"))) if (id.isDefined && killFlag) { action(id.get) } @@ -126,14 +128,14 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
-              <h4> Workers </h4>
+              <h4> Workers ({workers.length}) </h4>
               {workerTable}
-              <h4> Running Applications </h4>
+              <h4> Running Applications ({activeApps.length}) </h4>
               {activeAppsTable}
@@ -142,7 +144,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
           {if (hasDrivers) {
-              <h4> Running Drivers </h4>
+              <h4> Running Drivers ({activeDrivers.length}) </h4>
               {activeDriversTable}
@@ -152,7 +154,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
-              <h4> Completed Applications </h4>
+              <h4> Completed Applications ({completedApps.length}) </h4>
               {completedAppsTable}
@@ -162,7 +164,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
             if (hasDrivers) {
-              <h4> Completed Drivers </h4>
+              <h4> Completed Drivers ({completedDrivers.length}) </h4>
               {completedDriversTable}
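
The UI pages touched above (HistoryPage, ApplicationPage, MasterPage, and later LogPage) all follow the same pattern: every HttpServletRequest parameter is passed through UIUtils.stripXSS before it is parsed or rendered. The implementation of UIUtils.stripXSS is not part of this diff, so the sketch below is only an illustration of what such a sanitizer typically does; the object name XssSanitizerSketch and the character list are assumptions, not Spark code.

```scala
// Illustrative sketch only -- UIUtils.stripXSS itself is not shown in this patch.
// Demonstrates the sanitize-at-read pattern applied to request parameters above.
object XssSanitizerSketch {

  // Hypothetical sanitizer: drop characters commonly used in XSS payloads.
  // The character list here is an assumption, not Spark's actual implementation.
  def stripXSS(requestParameter: String): String = {
    if (requestParameter == null) null
    else requestParameter.replaceAll("[<>\"'%;()&+]", "")
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the call sites above: sanitize first, then parse with Option(...).getOrElse(...).
    val raw = "<script>alert(1)</script>"
    println(stripXSS(raw))                             // scriptalert1/script
    println(Option(stripXSS(null)).getOrElse("false")) // false -- null-safe, like the pages above
  }
}
```

The point of the pattern is that sanitization happens at the read site, before the usual Option(...).getOrElse(...).toBoolean parsing, so untrusted characters never reach the generated markup.
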
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 8cfd0f682932..35b7ddd46e4d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -17,10 +17,7 @@ package org.apache.spark.deploy.master.ui -import scala.collection.mutable.HashMap - -import org.eclipse.jetty.servlet.ServletContextHandler - +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.Master import org.apache.spark.internal.Logging import org.apache.spark.ui.{SparkUI, WebUI} @@ -38,7 +35,6 @@ class MasterWebUI( val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) - private val proxyHandlers = new HashMap[String, ServletContextHandler] initialize() @@ -54,16 +50,19 @@ class MasterWebUI( "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethods = Set("POST"))) } - def addProxyTargets(id: String, target: String): Unit = { - var endTarget = target.stripSuffix("/") - val handler = createProxyHandler("/proxy/" + id, endTarget) + def addProxy(): Unit = { + val handler = createProxyHandler(idToUiAddress) attachHandler(handler) - proxyHandlers(id) = handler } - def removeProxyTargets(id: String): Unit = { - proxyHandlers.remove(id).foreach(detachHandler) + def idToUiAddress(id: String): Option[String] = { + val state = masterEndpointRef.askSync[MasterStateResponse](RequestMasterState) + val maybeWorkerUiAddress = state.workers.find(_.id == id).map(_.webUiAddress) + val maybeAppUiAddress = state.activeApps.find(_.id == id).map(_.desc.appUiUrl) + + maybeWorkerUiAddress.orElse(maybeAppUiAddress) } + } private[master] object MasterWebUI { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index b30c980e95a9..e88195d95f27 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -11,7 +11,7 @@ * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and + * See the License for the specific language governing permissions and * limitations under the License. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 56620064c57f..0164084ab129 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -11,7 +11,7 @@ * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and + * See the License for the specific language governing permissions and * limitations under the License. 
*/ diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala similarity index 90% rename from resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala rename to core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index 5adeb8e605ff..78b0e6b2cbf3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.deploy.yarn.security +package org.apache.spark.deploy.security import scala.reflect.runtime.universe import scala.util.control.NonFatal @@ -28,11 +28,12 @@ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[security] class HBaseCredentialProvider extends ServiceCredentialProvider with Logging { +private[security] class HBaseDelegationTokenProvider + extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hbase" - override def obtainCredentials( + override def obtainDelegationTokens( hadoopConf: Configuration, sparkConf: SparkConf, creds: Credentials): Option[Long] = { @@ -55,7 +56,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide None } - override def credentialsRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { hbaseConf(hadoopConf).get("hbase.security.authentication") == "kerberos" } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala new file mode 100644 index 000000000000..c134b7ebe38f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging + +/** + * Manages all the registered HadoopDelegationTokenProviders and offer APIs for other modules to + * obtain delegation tokens and their renewal time. 
By default [[HadoopFSDelegationTokenProvider]], + * [[HiveDelegationTokenProvider]] and [[HBaseDelegationTokenProvider]] will be loaded in if not + * explicitly disabled. + * + * Also, each HadoopDelegationTokenProvider is controlled by + * spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to + * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be + * enabled/disabled by the configuration spark.security.credentials.hive.enabled. + * + * @param sparkConf Spark configuration + * @param hadoopConf Hadoop configuration + * @param fileSystems Delegation tokens will be fetched for these Hadoop filesystems. + */ +private[spark] class HadoopDelegationTokenManager( + sparkConf: SparkConf, + hadoopConf: Configuration, + fileSystems: Configuration => Set[FileSystem]) + extends Logging { + + private val deprecatedProviderEnabledConfigs = List( + "spark.yarn.security.tokens.%s.enabled", + "spark.yarn.security.credentials.%s.enabled") + private val providerEnabledConfig = "spark.security.credentials.%s.enabled" + + // Maintain all the registered delegation token providers + private val delegationTokenProviders = getDelegationTokenProviders + logDebug(s"Using the following delegation token providers: " + + s"${delegationTokenProviders.keys.mkString(", ")}.") + + /** Construct a [[HadoopDelegationTokenManager]] for the default Hadoop filesystem */ + def this(sparkConf: SparkConf, hadoopConf: Configuration) = { + this( + sparkConf, + hadoopConf, + hadoopConf => Set(FileSystem.get(hadoopConf).getHomeDirectory.getFileSystem(hadoopConf))) + } + + private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = { + val providers = List(new HadoopFSDelegationTokenProvider(fileSystems), + new HiveDelegationTokenProvider, + new HBaseDelegationTokenProvider) + + // Filter out providers for which spark.security.credentials.{service}.enabled is false. + providers + .filter { p => isServiceEnabled(p.serviceName) } + .map { p => (p.serviceName, p) } + .toMap + } + + def isServiceEnabled(serviceName: String): Boolean = { + val key = providerEnabledConfig.format(serviceName) + + deprecatedProviderEnabledConfigs.foreach { pattern => + val deprecatedKey = pattern.format(serviceName) + if (sparkConf.contains(deprecatedKey)) { + logWarning(s"${deprecatedKey} is deprecated. Please use ${key} instead.") + } + } + + val isEnabledDeprecated = deprecatedProviderEnabledConfigs.forall { pattern => + sparkConf + .getOption(pattern.format(serviceName)) + .map(_.toBoolean) + .getOrElse(true) + } + + sparkConf + .getOption(key) + .map(_.toBoolean) + .getOrElse(isEnabledDeprecated) + } + + /** + * Get delegation token provider for the specified service. + */ + def getServiceDelegationTokenProvider(service: String): Option[HadoopDelegationTokenProvider] = { + delegationTokenProviders.get(service) + } + + /** + * Writes delegation tokens to creds. Delegation tokens are fetched from all registered + * providers. + * + * @return Time after which the fetched delegation tokens should be renewed. + */ + def obtainDelegationTokens( + hadoopConf: Configuration, + creds: Credentials): Long = { + delegationTokenProviders.values.flatMap { provider => + if (provider.delegationTokensRequired(hadoopConf)) { + provider.obtainDelegationTokens(hadoopConf, sparkConf, creds) + } else { + logDebug(s"Service ${provider.serviceName} does not require a token." 
+ + s" Check your configuration to see if security is disabled or not.") + None + } + }.foldLeft(Long.MaxValue)(math.min) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala new file mode 100644 index 000000000000..1ba245e84af4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.Credentials + +import org.apache.spark.SparkConf + +/** + * Hadoop delegation token provider. + */ +private[spark] trait HadoopDelegationTokenProvider { + + /** + * Name of the service to provide delegation tokens. This name should be unique. Spark will + * internally use this name to differentiate delegation token providers. + */ + def serviceName: String + + /** + * Returns true if delegation tokens are required for this service. By default, it is based on + * whether Hadoop security is enabled. + */ + def delegationTokensRequired(hadoopConf: Configuration): Boolean + + /** + * Obtain delegation tokens for this service and get the time of the next renewal. + * @param hadoopConf Configuration of current Hadoop Compatible system. + * @param creds Credentials to add tokens and security keys to. + * @return If the returned tokens are renewable and can be renewed, return the time of the next + * renewal, otherwise None should be returned. + */ + def obtainDelegationTokens( + hadoopConf: Configuration, + sparkConf: SparkConf, + creds: Credentials): Option[Long] +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala similarity index 61% rename from resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala rename to core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index f65c886db944..300773c58b18 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -15,75 +15,100 @@ * limitations under the License. 
*/ -package org.apache.spark.deploy.yarn.security +package org.apache.spark.deploy.security import scala.collection.JavaConverters._ import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred.Master -import org.apache.hadoop.security.Credentials +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -private[security] class HadoopFSCredentialProvider - extends ServiceCredentialProvider with Logging { - // Token renewal interval, this value will be set in the first call, - // if None means no token renewer specified or no token can be renewed, - // so cannot get token renewal interval. +private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration => Set[FileSystem]) + extends HadoopDelegationTokenProvider with Logging { + + // This tokenRenewalInterval will be set in the first call to obtainDelegationTokens. + // If None, no token renewer is specified or no token can be renewed, + // so we cannot get the token renewal interval. private var tokenRenewalInterval: Option[Long] = null override val serviceName: String = "hadoopfs" - override def obtainCredentials( + override def obtainDelegationTokens( hadoopConf: Configuration, sparkConf: SparkConf, creds: Credentials): Option[Long] = { - // NameNode to access, used to get tokens from different FileSystems - val tmpCreds = new Credentials() - val tokenRenewer = getTokenRenewer(hadoopConf) - hadoopFSsToAccess(hadoopConf, sparkConf).foreach { dst => - val dstFs = dst.getFileSystem(hadoopConf) - logInfo("getting token for: " + dst) - dstFs.addDelegationTokens(tokenRenewer, tmpCreds) - } + + val fsToGetTokens = fileSystems(hadoopConf) + val fetchCreds = fetchDelegationTokens(getTokenRenewer(hadoopConf), fsToGetTokens, creds) // Get the token renewal interval if it is not set. It will only be called once. if (tokenRenewalInterval == null) { - tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, sparkConf) + tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, sparkConf, fsToGetTokens) } // Get the time of next renewal. val nextRenewalDate = tokenRenewalInterval.flatMap { interval => - val nextRenewalDates = tmpCreds.getAllTokens.asScala + val nextRenewalDates = fetchCreds.getAllTokens.asScala .filter(_.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]) - .map { t => - val identifier = t.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier] + .map { token => + val identifier = token + .decodeIdentifier() + .asInstanceOf[AbstractDelegationTokenIdentifier] identifier.getIssueDate + interval } if (nextRenewalDates.isEmpty) None else Some(nextRenewalDates.min) } - creds.addAll(tmpCreds) nextRenewalDate } + def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + UserGroupInformation.isSecurityEnabled + } + + private def getTokenRenewer(hadoopConf: Configuration): String = { + val tokenRenewer = Master.getMasterPrincipal(hadoopConf) + logDebug("Delegation token renewer is: " + tokenRenewer) + + if (tokenRenewer == null || tokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer." 
+ logError(errorMessage) + throw new SparkException(errorMessage) + } + + tokenRenewer + } + + private def fetchDelegationTokens( + renewer: String, + filesystems: Set[FileSystem], + creds: Credentials): Credentials = { + + filesystems.foreach { fs => + logInfo("getting token for: " + fs) + fs.addDelegationTokens(renewer, creds) + } + + creds + } + private def getTokenRenewalInterval( - hadoopConf: Configuration, sparkConf: SparkConf): Option[Long] = { + hadoopConf: Configuration, + sparkConf: SparkConf, + filesystems: Set[FileSystem]): Option[Long] = { // We cannot use the tokens generated with renewer yarn. Trying to renew // those will fail with an access control issue. So create new tokens with the logged in // user as renewer. sparkConf.get(PRINCIPAL).flatMap { renewer => val creds = new Credentials() - hadoopFSsToAccess(hadoopConf, sparkConf).foreach { dst => - val dstFs = dst.getFileSystem(hadoopConf) - dstFs.addDelegationTokens(renewer, creds) - } + fetchDelegationTokens(renewer, filesystems, creds) val renewIntervals = creds.getAllTokens.asScala.filter { _.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier] @@ -99,22 +124,4 @@ private[security] class HadoopFSCredentialProvider if (renewIntervals.isEmpty) None else Some(renewIntervals.min) } } - - private def getTokenRenewer(conf: Configuration): String = { - val delegTokenRenewer = Master.getMasterPrincipal(conf) - logDebug("delegation token renewer is: " + delegTokenRenewer) - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) - } - - delegTokenRenewer - } - - private def hadoopFSsToAccess(hadoopConf: Configuration, sparkConf: SparkConf): Set[Path] = { - sparkConf.get(FILESYSTEMS_TO_ACCESS).map(new Path(_)).toSet + - sparkConf.get(STAGING_DIR).map(new Path(_)) - .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory) - } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala similarity index 55% rename from resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala rename to core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index 16d8fc32bb42..b31cc595ed83 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -15,16 +15,17 @@ * limitations under the License. 
*/ -package org.apache.spark.deploy.yarn.security +package org.apache.spark.deploy.security import java.lang.reflect.UndeclaredThrowableException import java.security.PrivilegedExceptionAction -import scala.reflect.runtime.universe import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.Token @@ -33,79 +34,72 @@ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[security] class HiveCredentialProvider extends ServiceCredentialProvider with Logging { +private[security] class HiveDelegationTokenProvider + extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hive" + private val classNotFoundErrorStr = s"You are attempting to use the " + + s"${getClass.getCanonicalName}, but your Spark distribution is not built with Hive libraries." + private def hiveConf(hadoopConf: Configuration): Configuration = { try { - val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) - // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down - // to a Configuration and used without reflection - val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - // using the (Configuration, Class) constructor allows the current configuration to be - // included in the hive config. - val ctor = hiveConfClass.getDeclaredConstructor(classOf[Configuration], - classOf[Object].getClass) - ctor.newInstance(hadoopConf, hiveConfClass).asInstanceOf[Configuration] + new HiveConf(hadoopConf, classOf[HiveConf]) } catch { case NonFatal(e) => logDebug("Fail to create Hive Configuration", e) hadoopConf + case e: NoClassDefFoundError => + logWarning(classNotFoundErrorStr) + hadoopConf } } - override def credentialsRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { UserGroupInformation.isSecurityEnabled && hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty } - override def obtainCredentials( + override def obtainDelegationTokens( hadoopConf: Configuration, sparkConf: SparkConf, creds: Credentials): Option[Long] = { - val conf = hiveConf(hadoopConf) - - val principalKey = "hive.metastore.kerberos.principal" - val principal = conf.getTrimmed(principalKey, "") - require(principal.nonEmpty, s"Hive principal $principalKey undefined") - val metastoreUri = conf.getTrimmed("hive.metastore.uris", "") - require(metastoreUri.nonEmpty, "Hive metastore uri undefined") - - val currentUser = UserGroupInformation.getCurrentUser() - logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + - s"$principal at $metastoreUri") + try { + val conf = hiveConf(hadoopConf) - val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) - val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") - val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - val closeCurrent = hiveClass.getMethod("closeCurrent") + val principalKey = "hive.metastore.kerberos.principal" + val principal = conf.getTrimmed(principalKey, "") + require(principal.nonEmpty, s"Hive 
principal $principalKey undefined") + val metastoreUri = conf.getTrimmed("hive.metastore.uris", "") + require(metastoreUri.nonEmpty, "Hive metastore uri undefined") - try { - // get all the instance methods before invoking any - val getDelegationToken = hiveClass.getMethod("getDelegationToken", - classOf[String], classOf[String]) - val getHive = hiveClass.getMethod("get", hiveConfClass) + val currentUser = UserGroupInformation.getCurrentUser() + logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + + s"$principal at $metastoreUri") doAsRealUser { - val hive = getHive.invoke(null, conf) - val tokenStr = getDelegationToken.invoke(hive, currentUser.getUserName(), principal) - .asInstanceOf[String] + val hive = Hive.get(conf, classOf[HiveConf]) + val tokenStr = hive.getDelegationToken(currentUser.getUserName(), principal) + val hive2Token = new Token[DelegationTokenIdentifier]() hive2Token.decodeFromUrlString(tokenStr) logInfo(s"Get Token from hive metastore: ${hive2Token.toString}") creds.addToken(new Text("hive.server2.delegation.token"), hive2Token) } + + None } catch { case NonFatal(e) => - logDebug(s"Fail to get token from service $serviceName", e) + logDebug(s"Failed to get token from service $serviceName", e) + None + case e: NoClassDefFoundError => + logWarning(classNotFoundErrorStr) + None } finally { Utils.tryLogNonFatalError { - closeCurrent.invoke(null) + Hive.closeCurrent() } } - - None } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index cba4aaffe2ca..12e0dae3f5e5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -44,7 +44,7 @@ object CommandUtils extends Logging { memory: Int, sparkHome: String, substituteArguments: String => String, - classPaths: Seq[String] = Seq[String](), + classPaths: Seq[String] = Seq.empty, env: Map[String, String] = sys.env): ProcessBuilder = { val localCommand = buildLocalCommand( command, securityMgr, substituteArguments, classPaths, env) @@ -73,7 +73,7 @@ object CommandUtils extends Logging { command: Command, securityMgr: SecurityManager, substituteArguments: String => String, - classPath: Seq[String] = Seq[String](), + classPath: Seq[String] = Seq.empty, env: Map[String, String]): Command = { val libraryPathName = Utils.libraryPathEnvName val libraryPathEntries = command.libraryPathEntries @@ -96,7 +96,7 @@ object CommandUtils extends Logging { command.arguments.map(substituteArguments), newEnvironment, command.classPathEntries ++ classPath, - Seq[String](), // library path already captured in environment variable + Seq.empty, // library path already captured in environment variable // filter out auth secret from java options command.javaOpts.filterNot(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index e878c10183f6..58a181128eb4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -57,7 +57,8 @@ private[deploy] class DriverRunner( @volatile private[worker] var finalException: Option[Exception] = None // Timeout to wait for when trying to terminate a driver. 
- private val DRIVER_TERMINATE_TIMEOUT_MS = 10 * 1000 + private val DRIVER_TERMINATE_TIMEOUT_MS = + conf.getTimeAsMs("spark.worker.driverTerminateTimeout", "10s") // Decoupled for testing def setClock(_clock: Clock): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 6799f78ec0c1..c1671192e0c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,7 +19,10 @@ package org.apache.spark.deploy.worker import java.io.File +import org.apache.commons.lang3.StringUtils + import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -51,6 +54,7 @@ object DriverWrapper { new MutableURLClassLoader(Array(userJarUrl), currentLoader) } Thread.currentThread.setContextClassLoader(loader) + setupDependencies(loader, userJar) // Delegate to supplied main class val clazz = Utils.classForName(mainClass) @@ -66,4 +70,28 @@ object DriverWrapper { System.exit(-1) } } + + private def setupDependencies(loader: MutableURLClassLoader, userJar: String): Unit = { + val sparkConf = new SparkConf() + val secMgr = new SecurityManager(sparkConf) + val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf) + + val Seq(packagesExclusions, packages, repositories, ivyRepoPath) = + Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy") + .map(sys.props.get(_).orNull) + + val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions, + packages, repositories, ivyRepoPath) + val jars = { + val jarsProp = sys.props.get("spark.jars").orNull + if (!StringUtils.isBlank(resolvedMavenCoordinates)) { + SparkSubmit.mergeFileLists(jarsProp, resolvedMavenCoordinates) + } else { + jarsProp + } + } + val localJars = DependencyUtils.resolveAndDownloadJars(jars, userJar, sparkConf, hadoopConf, + secMgr) + DependencyUtils.addJarsToClassPath(localJars, loader) + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 00b9d1af373d..ed5fa4b839cd 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -38,7 +38,7 @@ import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.internal.Logging import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{SparkUncaughtExceptionHandler, ThreadUtils, Utils} private[deploy] class Worker( override val rpcEnv: RpcEnv, @@ -55,7 +55,7 @@ private[deploy] class Worker( private val host = rpcEnv.address.host private val port = rpcEnv.address.port - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) // A scheduled executor used to send messages at the specified time. @@ -99,6 +99,20 @@ private[deploy] class Worker( private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None + + /** + * Whether to use the master address in `masterRpcAddresses` if possible. If it's disabled, Worker + * will just use the address received from Master. 
+ */ + private val preferConfiguredMasterAddress = + conf.getBoolean("spark.worker.preferConfiguredMasterAddress", false) + /** + * The master address to connect in case of failure. When the connection is broken, worker will + * use this address to connect. This is usually just one of `masterRpcAddresses`. However, when + * a master is restarted or takes over leadership, it will be an address sent from master, which + * may not be in `masterRpcAddresses`. + */ + private var masterAddressToConnect: Option[RpcAddress] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" private var workerWebUiUrl: String = "" @@ -141,6 +155,8 @@ private[deploy] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) + val reverseProxy = conf.getBoolean("spark.ui.reverseProxy", false) + private var registerMasterFutures: Array[JFuture[_]] = null private var registrationRetryTimer: Option[JScheduledFuture[_]] = None @@ -196,13 +212,22 @@ private[deploy] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { + /** + * Change to use the new master. + * + * @param masterRef the new master ref + * @param uiUrl the new master Web UI address + * @param masterAddress the new master address which the worker should use to connect in case of + * failure + */ + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String, masterAddress: RpcAddress) { // activeMasterUrl it's a valid Spark url since we receive it from master. activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl + masterAddressToConnect = Some(masterAddress) master = Some(masterRef) connected = true - if (conf.getBoolean("spark.ui.reverseProxy", false)) { + if (reverseProxy) { logInfo(s"WorkerWebUI is available at $activeMasterWebUiUrl/proxy/$workerId") } // Cancel any outstanding re-registration attempts because we found a new master @@ -266,7 +291,8 @@ private[deploy] class Worker( if (registerMasterFutures != null) { registerMasterFutures.foreach(_.cancel(true)) } - val masterAddress = masterRef.address + val masterAddress = + if (preferConfiguredMasterAddress) masterAddressToConnect.get else masterRef.address registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { override def run(): Unit = { try { @@ -342,15 +368,27 @@ private[deploy] class Worker( } private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = { - masterEndpoint.send(RegisterWorker(workerId, host, port, self, cores, memory, workerWebUiUrl)) + masterEndpoint.send(RegisterWorker( + workerId, + host, + port, + self, + cores, + memory, + workerWebUiUrl, + masterEndpoint.address)) } private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized { msg match { - case RegisteredWorker(masterRef, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + case RegisteredWorker(masterRef, masterWebUiUrl, masterAddress) => + if (preferConfiguredMasterAddress) { + logInfo("Successfully registered with master " + masterAddress.toSparkURL) + } else { + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + } registered = true - changeMaster(masterRef, masterWebUiUrl) + changeMaster(masterRef, masterWebUiUrl, masterAddress) forwordMessageScheduler.scheduleAtFixedRate(new Runnable { override def run(): Unit 
= Utils.tryLogNonFatalError { self.send(SendHeartbeat) @@ -412,14 +450,13 @@ private[deploy] class Worker( } }(cleanupThreadExecutor) - cleanupFuture.onFailure { - case e: Throwable => - logError("App dir cleanup failed: " + e.getMessage, e) - }(cleanupThreadExecutor) + cleanupFuture.failed.foreach(e => + logError("App dir cleanup failed: " + e.getMessage, e) + )(cleanupThreadExecutor) case MasterChanged(masterRef, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) - changeMaster(masterRef, masterWebUiUrl) + changeMaster(masterRef, masterWebUiUrl, masterRef.address) val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) @@ -561,7 +598,8 @@ private[deploy] class Worker( } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (master.exists(_.address == remoteAddress)) { + if (master.exists(_.address == remoteAddress) || + masterAddressToConnect.exists(_ == remoteAddress)) { logInfo(s"$remoteAddress Disassociated !") masterDisconnected() } @@ -583,10 +621,9 @@ private[deploy] class Worker( dirList.foreach { dir => Utils.deleteRecursively(new File(dir)) } - }(cleanupThreadExecutor).onFailure { - case e: Throwable => - logError(s"Clean up app dir $dirList failed: ${e.getMessage}", e) - }(cleanupThreadExecutor) + }(cleanupThreadExecutor).failed.foreach(e => + logError(s"Clean up app dir $dirList failed: ${e.getMessage}", e) + )(cleanupThreadExecutor) } shuffleService.applicationRemoved(id) } @@ -700,11 +737,24 @@ private[deploy] object Worker extends Logging { val ENDPOINT_NAME = "Worker" def main(argStrings: Array[String]) { + Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler( + exitOnUncaughtException = false)) Utils.initDaemon(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir, conf = conf) + // With external shuffle service enabled, if we request to launch multiple workers on one host, + // we can only successfully launch the first worker and the rest fails, because with the port + // bound, we may launch no more than one external shuffle service on each host. + // When this happens, we should give explicit reason of failure instead of fail silently. For + // more detail see SPARK-20989. 
+ val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + val sparkWorkerInstances = scala.sys.env.getOrElse("SPARK_WORKER_INSTANCES", "1").toInt + require(externalShuffleServiceEnabled == false || sparkWorkerInstances <= 1, + "Starting multiple workers on one host is failed because we may launch no more than one " + + "external shuffle service on each host, please set spark.shuffle.service.enabled to " + + "false or set SPARK_WORKER_INSTANCES to 1 to resolve the conflict.") rpcEnv.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 777020d4d5c8..580281288b06 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -21,8 +21,8 @@ import java.lang.management.ManagementFactory import scala.annotation.tailrec -import org.apache.spark.util.{IntParam, MemoryParam, Utils} import org.apache.spark.SparkConf +import org.apache.spark.util.{IntParam, MemoryParam, Utils} /** * Command-line parser for the worker. @@ -68,12 +68,12 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 80dc9bf8779d..2f5a5642d3ca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -33,13 +33,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with private val supportedLogTypes = Set("stderr", "stdout") private val defaultBytes = 100 * 1024 + // stripXSS is called first to remove suspicious characters used in XSS attacks def renderLog(request: HttpServletRequest): String = { - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) + val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) + val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) + val logType = UIUtils.stripXSS(request.getParameter("logType")) + val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) + val byteLength = + Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + .getOrElse(defaultBytes) val logDir = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => @@ -55,13 +58,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with pre + logText } + // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: 
HttpServletRequest): Seq[Node] = { - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) + val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) + val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) + val logType = UIUtils.stripXSS(request.getParameter("logType")) + val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) + val byteLength = + Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + .getOrElse(defaultBytes) val (logDir, params, pageName) = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 1ad973122b60..ce84bc4dae32 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -23,8 +23,8 @@ import scala.xml.Node import org.json4s.JValue +import org.apache.spark.deploy.{ExecutorState, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} -import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} import org.apache.spark.ui.{UIUtils, WebUIPage} @@ -51,9 +51,11 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { val driverHeaders = Seq("DriverID", "Main Class", "State", "Cores", "Memory", "Logs", "Notes") val runningDrivers = workerState.drivers.sortBy(_.driverId).reverse - val runningDriverTable = UIUtils.listingTable(driverHeaders, driverRow, runningDrivers) + val runningDriverTable = UIUtils.listingTable[DriverRunner](driverHeaders, + driverRow(workerState.workerId, _), runningDrivers) val finishedDrivers = workerState.finishedDrivers.sortBy(_.driverId).reverse - val finishedDriverTable = UIUtils.listingTable(driverHeaders, driverRow, finishedDrivers) + val finishedDriverTable = UIUtils.listingTable[DriverRunner](driverHeaders, + driverRow(workerState.workerId, _), finishedDrivers) // For now we only show driver information if the user has submitted drivers to the cluster. // This is until we integrate the notion of drivers and applications in the UI. @@ -102,6 +104,11 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { } def executorRow(executor: ExecutorRunner): Seq[Node] = { + val workerUrlRef = UIUtils.makeHref(parent.worker.reverseProxy, executor.workerId, + parent.webUrl) + val appUrlRef = UIUtils.makeHref(parent.worker.reverseProxy, executor.appId, + executor.appDesc.appUiUrl) + {executor.execId} {executor.cores} @@ -112,21 +119,30 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
           <li><strong>ID:</strong> {executor.appId}</li>
-          <li><strong>Name:</strong> {executor.appDesc.name}</li>
+          <li><strong>Name:</strong>
+          {
+            if ({executor.state == ExecutorState.RUNNING} && executor.appDesc.appUiUrl.nonEmpty) {
+              <a href={appUrlRef}> {executor.appDesc.name}</a>
+            } else {
+              {executor.appDesc.name}
+            }
+          }
+          </li>
           <li><strong>User:</strong> {executor.appDesc.user}</li>
- stdout - stderr + stdout + stderr } - def driverRow(driver: DriverRunner): Seq[Node] = { + def driverRow(workerId: String, driver: DriverRunner): Seq[Node] = { + val workerUrlRef = UIUtils.makeHref(parent.worker.reverseProxy, workerId, parent.webUrl) {driver.driverId} {driver.driverDesc.command.arguments(2)} @@ -138,8 +154,8 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { {Utils.megabytesToString(driver.driverDesc.mem)} - stdout - stderr + stdout + stderr {driver.finalException.getOrElse("")} diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b2b26ee107c0..d27362ae85be 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -163,9 +163,9 @@ private[spark] class CoarseGrainedExecutorBackend( if (notifyDriver && driver.nonEmpty) { driver.get.ask[Boolean]( RemoveExecutor(executorId, new ExecutorLossReason(reason)) - ).onFailure { case e => + ).failed.foreach(e => logWarning(s"Unable to notify the driver due to " + e.getMessage, e) - }(ThreadUtils.sameThread) + )(ThreadUtils.sameThread) } System.exit(code) @@ -191,11 +191,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf - val port = executorConf.getInt("spark.executor.port", 0) val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, - port, + -1, executorConf, new SecurityManager(executorConf), clientMode = true) @@ -220,8 +219,13 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { SparkHadoopUtil.get.startCredentialUpdater(driverConf) } + cfg.hadoopDelegationCreds.foreach { hadoopCreds => + val creds = SparkHadoopUtil.get.deserialize(hadoopCreds) + SparkHadoopUtil.get.addCurrentUserCredentials(creds) + } + val env = SparkEnv.createExecutorEnv( - driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false) + driverConf, executorId, hostname, cores, cfg.ioEncryptionKey, isLocal = false) env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala index 326e04241977..3e0d52cb4ccb 100644 --- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import org.apache.spark.{TaskCommitDenied, TaskFailedReason} +import org.apache.spark.TaskCommitDenied /** * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver. 
@@ -29,5 +29,5 @@ private[spark] class CommitDeniedException( attemptNumber: Int) extends Exception(msg) { - def toTaskFailedReason: TaskFailedReason = TaskCommitDenied(jobID, splitID, attemptNumber) + def toTaskCommitDeniedReason: TaskCommitDenied = TaskCommitDenied(jobID, splitID, attemptNumber) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 51b6c373c4da..2ecbb749d1fb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -56,7 +56,7 @@ private[spark] class Executor( env: SparkEnv, userClassPath: Seq[URL] = Nil, isLocal: Boolean = false, - uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler) + uncaughtExceptionHandler: UncaughtExceptionHandler = new SparkUncaughtExceptionHandler) extends Logging { logInfo(s"Starting executor ID $executorId on host $executorHostname") @@ -71,7 +71,7 @@ private[spark] class Executor( private val conf = env.conf // No ip or host:port - just hostname - Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname) // must not have port specified. assert (0 == Utils.parseHostPort(executorHostname)._2) @@ -113,8 +113,9 @@ private[spark] class Executor( private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]() if (!isLocal) { - env.metricsSystem.registerSource(executorSource) env.blockManager.initialize(conf.getAppId) + env.metricsSystem.registerSource(executorSource) + env.metricsSystem.registerSource(env.blockManager.shuffleMetricsSource) } // Whether to load classes in user jars before those in Spark jars @@ -130,6 +131,9 @@ private[spark] class Executor( // Set the classloader for serializer env.serializer.setDefaultClassLoader(replClassLoader) + // SPARK-21928. SerializerManager's internal instance of Kryo might get used in netty threads + // for fetching remote cached RDD blocks, so need to make sure it uses the right classloader too. + env.serializerManager.setDefaultClassLoader(replClassLoader) // Max size of direct result. If task result is bigger than this, we use the block manager // to send the result back. @@ -322,8 +326,14 @@ private[spark] class Executor( throw new TaskKilledException(killReason.get) } - logDebug("Task " + taskId + "'s epoch is " + task.epoch) - env.mapOutputTracker.updateEpoch(task.epoch) + // The purpose of updating the epoch here is to invalidate executor map output status cache + // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be + // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so + // we don't need to make any special calls here. + if (!isLocal) { + logDebug("Task " + taskId + "'s epoch is " + task.epoch) + env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch) + } // Run the actual task and measure its runtime. 
taskStart = System.currentTimeMillis() @@ -425,6 +435,7 @@ private[spark] class Executor( } } + setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { @@ -456,9 +467,9 @@ private[spark] class Executor( taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) case CausedBy(cDE: CommitDeniedException) => - val reason = cDE.toTaskFailedReason + val reason = cDE.toTaskCommitDeniedReason setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason)) case t: Throwable => // Attempt to exit cleanly by informing the driver of our failure. @@ -466,29 +477,38 @@ private[spark] class Executor( // the default uncaught exception handler, which will terminate the Executor. logError(s"Exception in $taskName (TID $taskId)", t) - // Collect latest accumulator values to report back to the driver - val accums: Seq[AccumulatorV2[_, _]] = - if (task != null) { - task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart) - task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) - task.collectAccumulatorUpdates(taskFailed = true) - } else { - Seq.empty - } + // SPARK-20904: Do not report failure to driver if if happened during shut down. Because + // libraries may set up shutdown hooks that race with running tasks during shutdown, + // spurious failures may occur and can result in improper accounting in the driver (e.g. + // the task failure would not be ignored if the shutdown happened because of premption, + // instead of an app issue). + if (!ShutdownHookManager.inShutdown()) { + // Collect latest accumulator values to report back to the driver + val accums: Seq[AccumulatorV2[_, _]] = + if (task != null) { + task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart) + task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) + task.collectAccumulatorUpdates(taskFailed = true) + } else { + Seq.empty + } - val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) - val serializedTaskEndReason = { - try { - ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums)) - } catch { - case _: NotSerializableException => - // t is not serializable so just send the stacktrace - ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums)) + val serializedTaskEndReason = { + try { + ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums)) + } catch { + case _: NotSerializableException => + // t is not serializable so just send the stacktrace + ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums)) + } } + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) + } else { + logInfo("Not reporting error to driver during JVM shutdown.") } - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. 
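The Executor.scala hunk above (SPARK-20904) stops reporting a task failure to the driver when the error happens while the JVM is already shutting down, since shutdown hooks registered by user libraries can race with running tasks and surface spurious exceptions. Below is a minimal, self-contained sketch of that guard pattern; it uses a plain JVM shutdown hook and a hypothetical `reportTaskFailure` helper instead of Spark's internal `ShutdownHookManager` and `ExecutorBackend`, so treat it as an illustration of the idea rather than the actual implementation.

```scala
import java.util.concurrent.atomic.AtomicBoolean

// Illustrative sketch only: a stand-in for the inShutdown() check used in Executor.scala.
object ShutdownGuardSketch {
  private val shuttingDown = new AtomicBoolean(false)

  // A JVM shutdown hook flips the flag, mimicking ShutdownHookManager's "in shutdown" state.
  Runtime.getRuntime.addShutdownHook(new Thread(new Runnable {
    override def run(): Unit = shuttingDown.set(true)
  }))

  // Hypothetical helper: `report` stands in for execBackend.statusUpdate(taskId, FAILED, ...).
  def reportTaskFailure(taskId: Long, error: Throwable)(report: Throwable => Unit): Unit = {
    if (!shuttingDown.get()) {
      // Normal path: propagate the failure so the driver can account for it.
      report(error)
    } else {
      // During shutdown, failures are likely caused by racing shutdown hooks,
      // so log locally instead of reporting a possibly spurious task failure.
      println(s"Not reporting error for task $taskId during JVM shutdown: ${error.getMessage}")
    }
  }
}
```

In the diff itself the equivalent check wraps the whole `case t: Throwable` branch, so both the accumulator collection and the `statusUpdate(TaskState.FAILED, ...)` call are skipped when `ShutdownHookManager.inShutdown()` is true.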
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 8dd1a1ea059b..4be395c8358b 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -31,6 +31,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[executor] val _remoteBlocksFetched = new LongAccumulator private[executor] val _localBlocksFetched = new LongAccumulator private[executor] val _remoteBytesRead = new LongAccumulator + private[executor] val _remoteBytesReadToDisk = new LongAccumulator private[executor] val _localBytesRead = new LongAccumulator private[executor] val _fetchWaitTime = new LongAccumulator private[executor] val _recordsRead = new LongAccumulator @@ -50,6 +51,11 @@ class ShuffleReadMetrics private[spark] () extends Serializable { */ def remoteBytesRead: Long = _remoteBytesRead.sum + /** + * Total number of remotes bytes read to disk from the shuffle by this task. + */ + def remoteBytesReadToDisk: Long = _remoteBytesReadToDisk.sum + /** * Shuffle data that was read from the local disk (as opposed to from a remote executor). */ @@ -80,6 +86,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[spark] def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched.add(v) private[spark] def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched.add(v) private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v) + private[spark] def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk.add(v) private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v) private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v) private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) @@ -87,6 +94,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[spark] def setRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.setValue(v) private[spark] def setLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.setValue(v) private[spark] def setRemoteBytesRead(v: Long): Unit = _remoteBytesRead.setValue(v) + private[spark] def setRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk.setValue(v) private[spark] def setLocalBytesRead(v: Long): Unit = _localBytesRead.setValue(v) private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v) private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v) @@ -99,6 +107,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { _remoteBlocksFetched.setValue(0) _localBlocksFetched.setValue(0) _remoteBytesRead.setValue(0) + _remoteBytesReadToDisk.setValue(0) _localBytesRead.setValue(0) _fetchWaitTime.setValue(0) _recordsRead.setValue(0) @@ -106,6 +115,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { _remoteBlocksFetched.add(metric.remoteBlocksFetched) _localBlocksFetched.add(metric.localBlocksFetched) _remoteBytesRead.add(metric.remoteBytesRead) + _remoteBytesReadToDisk.add(metric.remoteBytesReadToDisk) _localBytesRead.add(metric.localBytesRead) _fetchWaitTime.add(metric.fetchWaitTime) _recordsRead.add(metric.recordsRead) @@ -122,6 +132,7 @@ private[spark] class TempShuffleReadMetrics { private[this] var _remoteBlocksFetched = 0L private[this] var _localBlocksFetched = 0L private[this] var _remoteBytesRead = 0L + private[this] var _remoteBytesReadToDisk = 0L 
private[this] var _localBytesRead = 0L private[this] var _fetchWaitTime = 0L private[this] var _recordsRead = 0L @@ -129,6 +140,7 @@ private[spark] class TempShuffleReadMetrics { def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v + def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v def incLocalBytesRead(v: Long): Unit = _localBytesRead += v def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v def incRecordsRead(v: Long): Unit = _recordsRead += v @@ -136,6 +148,7 @@ private[spark] class TempShuffleReadMetrics { def remoteBlocksFetched: Long = _remoteBlocksFetched def localBlocksFetched: Long = _localBlocksFetched def remoteBytesRead: Long = _remoteBytesRead + def remoteBytesReadToDisk: Long = _remoteBytesReadToDisk def localBytesRead: Long = _localBytesRead def fetchWaitTime: Long = _fetchWaitTime def recordsRead: Long = _recordsRead diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index a3ce3d1ccc5e..85b2745a2aec 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -112,6 +112,12 @@ class TaskMetrics private[spark] () extends Serializable { /** * Storage statuses of any blocks that have been updated as a result of this task. + * + * Tracking the _updatedBlockStatuses can use a lot of memory. + * It is not used anywhere inside of Spark so we would ideally remove it, but its exposed to + * the user in SparkListenerTaskEnd so the api is kept for compatibility. + * Tracking can be turned off to save memory via config + * TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES. */ def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = { // This is called on driver. All accumulator updates have a fixed value. 
So it's safe to use @@ -215,6 +221,7 @@ class TaskMetrics private[spark] () extends Serializable { shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched, shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched, shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead, + shuffleRead.REMOTE_BYTES_READ_TO_DISK -> shuffleReadMetrics._remoteBytesReadToDisk, shuffleRead.LOCAL_BYTES_READ -> shuffleReadMetrics._localBytesRead, shuffleRead.FETCH_WAIT_TIME -> shuffleReadMetrics._fetchWaitTime, shuffleRead.RECORDS_READ -> shuffleReadMetrics._recordsRead, diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 9606c4754314..17cdba4f1305 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -27,9 +27,9 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} -import org.apache.spark.internal.config import org.apache.spark.SparkContext import org.apache.spark.annotation.Since +import org.apache.spark.internal.config /** * A general format for reading whole files in as streams, byte arrays, diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index fa34f1e886c7..f47cd38d712c 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -21,11 +21,8 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.Text -import org.apache.hadoop.mapreduce.InputSplit -import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat -import org.apache.hadoop.mapreduce.RecordReader -import org.apache.hadoop.mapreduce.TaskAttemptContext /** * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala index c7f2847731fc..c0d709ad25f2 100644 --- a/core/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala @@ -96,47 +96,59 @@ trait Logging { } protected def initializeLogIfNecessary(isInterpreter: Boolean): Unit = { + initializeLogIfNecessary(isInterpreter, silent = false) + } + + protected def initializeLogIfNecessary( + isInterpreter: Boolean, + silent: Boolean = false): Boolean = { if (!Logging.initialized) { Logging.initLock.synchronized { if (!Logging.initialized) { - initializeLogging(isInterpreter) + initializeLogging(isInterpreter, silent) + return true } } } + false } - private def initializeLogging(isInterpreter: Boolean): Unit = { + private def initializeLogging(isInterpreter: Boolean, silent: Boolean): Unit = { // Don't use a logger in here, as this is itself occurring during initialization of a logger // If Log4j 1.2 is being used, but is not initialized, load a default properties file - val binderClass = 
StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr - // This distinguishes the log4j 1.2 binding, currently - // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently - // org.apache.logging.slf4j.Log4jLoggerFactory - val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) - if (usingLog4j12) { + if (Logging.isLog4j12()) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements // scalastyle:off println if (!log4j12Initialized) { + Logging.defaultSparkLog4jConfig = true val defaultLogProps = "org/apache/spark/log4j-defaults.properties" Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { case Some(url) => PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + if (!silent) { + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + } case None => System.err.println(s"Spark was unable to load $defaultLogProps") } } + val rootLogger = LogManager.getRootLogger() + if (Logging.defaultRootLevel == null) { + Logging.defaultRootLevel = rootLogger.getLevel() + } + if (isInterpreter) { // Use the repl's main class to define the default log level when running the shell, // overriding the root logger's config if they're different. - val rootLogger = LogManager.getRootLogger() val replLogger = LogManager.getLogger(logName) val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN) if (replLevel != rootLogger.getEffectiveLevel()) { - System.err.printf("Setting default log level to \"%s\".\n", replLevel) - System.err.println("To adjust logging level use sc.setLogLevel(newLevel). " + - "For SparkR, use setLogLevel(newLevel).") + if (!silent) { + System.err.printf("Setting default log level to \"%s\".\n", replLevel) + System.err.println("To adjust logging level use sc.setLogLevel(newLevel). " + + "For SparkR, use setLogLevel(newLevel).") + } rootLogger.setLevel(replLevel) } } @@ -150,8 +162,11 @@ trait Logging { } } -private object Logging { +private[spark] object Logging { @volatile private var initialized = false + @volatile private var defaultRootLevel: Level = null + @volatile private var defaultSparkLog4jConfig = false + val initLock = new Object() try { // We use reflection here to handle the case where users remove the @@ -165,4 +180,29 @@ private object Logging { } catch { case e: ClassNotFoundException => // can't log anything yet so just fail silently } + + /** + * Marks the logging system as not initialized. This does a best effort at resetting the + * logging system to its initial state so that the next class to use logging triggers + * initialization again. 
+ */ + def uninitialize(): Unit = initLock.synchronized { + if (isLog4j12()) { + if (defaultSparkLog4jConfig) { + defaultSparkLog4jConfig = false + LogManager.resetConfiguration() + } else { + LogManager.getRootLogger().setLevel(defaultRootLevel) + } + } + this.initialized = false + } + + private def isLog4j12(): Boolean = { + // This distinguishes the log4j 1.2 binding, currently + // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently + // org.apache.logging.slf4j.Log4jLoggerFactory + val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr + "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index e5d60a7ef098..8f4c1b60920d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -28,7 +28,7 @@ private object ConfigHelpers { def toNumber[T](s: String, converter: String => T, key: String, configType: String): T = { try { - converter(s) + converter(s.trim) } catch { case _: NumberFormatException => throw new IllegalArgumentException(s"$key should be $configType, but was $s") @@ -37,7 +37,7 @@ private object ConfigHelpers { def toBoolean(s: String, key: String): Boolean = { try { - s.toBoolean + s.trim.toBoolean } catch { case _: IllegalArgumentException => throw new IllegalArgumentException(s"$key should be boolean, but was $s") @@ -126,8 +126,8 @@ private[spark] class TypedConfigBuilder[T]( /** Creates a [[ConfigEntry]] that does not have a default value. */ def createOptional: OptionalConfigEntry[T] = { - val entry = new OptionalConfigEntry[T](parent.key, converter, stringConverter, parent._doc, - parent._public) + val entry = new OptionalConfigEntry[T](parent.key, parent._alternatives, converter, + stringConverter, parent._doc, parent._public) parent._onCreate.foreach(_(entry)) entry } @@ -140,8 +140,8 @@ private[spark] class TypedConfigBuilder[T]( createWithDefaultString(default.asInstanceOf[String]) } else { val transformedDefault = converter(stringConverter(default)) - val entry = new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter, - stringConverter, parent._doc, parent._public) + val entry = new ConfigEntryWithDefault[T](parent.key, parent._alternatives, + transformedDefault, converter, stringConverter, parent._doc, parent._public) parent._onCreate.foreach(_(entry)) entry } @@ -149,8 +149,8 @@ private[spark] class TypedConfigBuilder[T]( /** Creates a [[ConfigEntry]] with a function to determine the default value */ def createWithDefaultFunction(defaultFunc: () => T): ConfigEntry[T] = { - val entry = new ConfigEntryWithDefaultFunction[T](parent.key, defaultFunc, converter, - stringConverter, parent._doc, parent._public) + val entry = new ConfigEntryWithDefaultFunction[T](parent.key, parent._alternatives, defaultFunc, + converter, stringConverter, parent._doc, parent._public) parent._onCreate.foreach(_ (entry)) entry } @@ -160,8 +160,8 @@ private[spark] class TypedConfigBuilder[T]( * [[String]] and must be a valid value for the entry. 
*/ def createWithDefaultString(default: String): ConfigEntry[T] = { - val entry = new ConfigEntryWithDefaultString[T](parent.key, default, converter, stringConverter, - parent._doc, parent._public) + val entry = new ConfigEntryWithDefaultString[T](parent.key, parent._alternatives, default, + converter, stringConverter, parent._doc, parent._public) parent._onCreate.foreach(_(entry)) entry } @@ -180,6 +180,7 @@ private[spark] case class ConfigBuilder(key: String) { private[config] var _public = true private[config] var _doc = "" private[config] var _onCreate: Option[ConfigEntry[_] => Unit] = None + private[config] var _alternatives = List.empty[String] def internal(): ConfigBuilder = { _public = false @@ -200,6 +201,11 @@ private[spark] case class ConfigBuilder(key: String) { this } + def withAlternative(key: String): ConfigBuilder = { + _alternatives = _alternatives :+ key + this + } + def intConf: TypedConfigBuilder[Int] = { new TypedConfigBuilder(this, toNumber(_, _.toInt, key, "int")) } @@ -229,7 +235,7 @@ private[spark] case class ConfigBuilder(key: String) { } def fallbackConf[T](fallback: ConfigEntry[T]): ConfigEntry[T] = { - new FallbackConfigEntry(key, _doc, _public, fallback) + new FallbackConfigEntry(key, _alternatives, _doc, _public, fallback) } def regexConf: TypedConfigBuilder[Regex] = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala index e86712e84d6a..f1190289244e 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -41,6 +41,7 @@ package org.apache.spark.internal.config */ private[spark] abstract class ConfigEntry[T] ( val key: String, + val alternatives: List[String], val valueConverter: String => T, val stringConverter: T => String, val doc: String, @@ -52,6 +53,10 @@ private[spark] abstract class ConfigEntry[T] ( def defaultValueString: String + protected def readString(reader: ConfigReader): Option[String] = { + alternatives.foldLeft(reader.get(key))((res, nextKey) => res.orElse(reader.get(nextKey))) + } + def readFrom(reader: ConfigReader): T def defaultValue: Option[T] = None @@ -59,63 +64,64 @@ private[spark] abstract class ConfigEntry[T] ( override def toString: String = { s"ConfigEntry(key=$key, defaultValue=$defaultValueString, doc=$doc, public=$isPublic)" } - } private class ConfigEntryWithDefault[T] ( key: String, + alternatives: List[String], _defaultValue: T, valueConverter: String => T, stringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + extends ConfigEntry(key, alternatives, valueConverter, stringConverter, doc, isPublic) { override def defaultValue: Option[T] = Some(_defaultValue) override def defaultValueString: String = stringConverter(_defaultValue) def readFrom(reader: ConfigReader): T = { - reader.get(key).map(valueConverter).getOrElse(_defaultValue) + readString(reader).map(valueConverter).getOrElse(_defaultValue) } } private class ConfigEntryWithDefaultFunction[T] ( key: String, + alternatives: List[String], _defaultFunction: () => T, valueConverter: String => T, stringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + extends ConfigEntry(key, alternatives, valueConverter, stringConverter, doc, isPublic) { override def defaultValue: Option[T] = 
Some(_defaultFunction()) override def defaultValueString: String = stringConverter(_defaultFunction()) def readFrom(reader: ConfigReader): T = { - reader.get(key).map(valueConverter).getOrElse(_defaultFunction()) + readString(reader).map(valueConverter).getOrElse(_defaultFunction()) } } private class ConfigEntryWithDefaultString[T] ( key: String, + alternatives: List[String], _defaultValue: String, valueConverter: String => T, stringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + extends ConfigEntry(key, alternatives, valueConverter, stringConverter, doc, isPublic) { override def defaultValue: Option[T] = Some(valueConverter(_defaultValue)) override def defaultValueString: String = _defaultValue def readFrom(reader: ConfigReader): T = { - val value = reader.get(key).getOrElse(reader.substitute(_defaultValue)) + val value = readString(reader).getOrElse(reader.substitute(_defaultValue)) valueConverter(value) } - } @@ -124,19 +130,20 @@ private class ConfigEntryWithDefaultString[T] ( */ private[spark] class OptionalConfigEntry[T]( key: String, + alternatives: List[String], val rawValueConverter: String => T, val rawStringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry[Option[T]](key, s => Some(rawValueConverter(s)), + extends ConfigEntry[Option[T]](key, alternatives, + s => Some(rawValueConverter(s)), v => v.map(rawStringConverter).orNull, doc, isPublic) { override def defaultValueString: String = "" override def readFrom(reader: ConfigReader): Option[T] = { - reader.get(key).map(rawValueConverter) + readString(reader).map(rawValueConverter) } - } /** @@ -144,17 +151,18 @@ private[spark] class OptionalConfigEntry[T]( */ private class FallbackConfigEntry[T] ( key: String, + alternatives: List[String], doc: String, isPublic: Boolean, private[config] val fallback: ConfigEntry[T]) - extends ConfigEntry[T](key, fallback.valueConverter, fallback.stringConverter, doc, isPublic) { + extends ConfigEntry[T](key, alternatives, + fallback.valueConverter, fallback.stringConverter, doc, isPublic) { override def defaultValueString: String = s"" override def readFrom(reader: ConfigReader): T = { - reader.get(key).map(valueConverter).getOrElse(fallback.readFrom(reader)) + readString(reader).map(valueConverter).getOrElse(fallback.readFrom(reader)) } - } private[spark] object ConfigEntry { diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala index 97f56a64d600..5d98a1185f05 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala @@ -47,28 +47,16 @@ private[spark] class MapProvider(conf: JMap[String, String]) extends ConfigProvi } /** - * A config provider that only reads Spark config keys, and considers default values for known - * configs when fetching configuration values. + * A config provider that only reads Spark config keys. 
*/ private[spark] class SparkConfigProvider(conf: JMap[String, String]) extends ConfigProvider { - import ConfigEntry._ - override def get(key: String): Option[String] = { if (key.startsWith("spark.")) { - Option(conf.get(key)).orElse(defaultValueString(key)) + Option(conf.get(key)) } else { None } } - private def defaultValueString(key: String): Option[String] = { - findEntry(key) match { - case e: ConfigEntryWithDefault[_] => Option(e.defaultValueString) - case e: ConfigEntryWithDefaultString[_] => Option(e.defaultValueString) - case e: FallbackConfigEntry[_] => get(e.fallback.key) - case _ => None - } - } - } diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala index c62de9bfd8fc..c1ab22150d02 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala @@ -92,7 +92,7 @@ private[spark] class ConfigReader(conf: ConfigProvider) { require(!usedRefs.contains(ref), s"Circular reference in $input: $ref") val replacement = bindings.get(prefix) - .flatMap(_.get(name)) + .flatMap(getOrDefault(_, name)) .map { v => substitute(v, usedRefs + ref) } .getOrElse(m.matched) Regex.quoteReplacement(replacement) @@ -102,4 +102,20 @@ private[spark] class ConfigReader(conf: ConfigProvider) { } } + /** + * Gets the value of a config from the given `ConfigProvider`. If no value is found for this + * config, and the `ConfigEntry` defines this config has default value, return the default value. + */ + private def getOrDefault(conf: ConfigProvider, key: String): Option[String] = { + conf.get(key).orElse { + ConfigEntry.findEntry(key) match { + case e: ConfigEntryWithDefault[_] => Option(e.defaultValueString) + case e: ConfigEntryWithDefaultString[_] => Option(e.defaultValueString) + case e: ConfigEntryWithDefaultFunction[_] => Option(e.defaultValueString) + case e: FallbackConfigEntry[_] => getOrDefault(conf, e.fallback.key) + case _ => None + } + } + } + } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7f7921d56f49..44a2815b81a7 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -87,7 +87,7 @@ package object config { .intConf .createOptional - private[spark] val PY_FILES = ConfigBuilder("spark.submit.pyFiles") + private[spark] val PY_FILES = ConfigBuilder("spark.yarn.dist.pyFiles") .internal() .stringConf .toSequence @@ -149,13 +149,34 @@ package object config { .internal() .timeConf(TimeUnit.MILLISECONDS) .createOptional + + private[spark] val BLACKLIST_FETCH_FAILURE_ENABLED = + ConfigBuilder("spark.blacklist.application.fetchFailure.enabled") + .booleanConf + .createWithDefault(false) // End blacklist confs - private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE = - ConfigBuilder("spark.scheduler.listenerbus.eventqueue.size") + private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE = + ConfigBuilder("spark.files.fetchFailure.unRegisterOutputOnHost") + .doc("Whether to un-register all the outputs on the host in condition that we receive " + + " a FetchFailure. 
This is set default to false, which means, we only un-register the " + + " outputs related to the exact executor(instead of the host) on a FetchFailure.") + .booleanConf + .createWithDefault(false) + + private[spark] val LISTENER_BUS_EVENT_QUEUE_CAPACITY = + ConfigBuilder("spark.scheduler.listenerbus.eventqueue.capacity") + .withAlternative("spark.scheduler.listenerbus.eventqueue.size") .intConf + .checkValue(_ > 0, "The capacity of listener bus event queue must not be negative") .createWithDefault(10000) + private[spark] val LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED = + ConfigBuilder("spark.scheduler.listenerbus.metrics.maxListenerClassesTimed") + .internal() + .intConf + .createWithDefault(128) + // This property sets the root namespace for metrics reporting private[spark] val METRICS_NAMESPACE = ConfigBuilder("spark.metrics.namespace") .stringConf @@ -201,7 +222,7 @@ package object config { private[spark] val DRIVER_HOST_ADDRESS = ConfigBuilder("spark.driver.host") .doc("Address of driver endpoints.") .stringConf - .createWithDefault(Utils.localHostName()) + .createWithDefault(Utils.localCanonicalHostName()) private[spark] val DRIVER_BIND_ADDRESS = ConfigBuilder("spark.driver.bindAddress") .doc("Address where to bind network listen sockets on the driver.") @@ -272,10 +293,121 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val BUFFER_WRITE_CHUNK_SIZE = + ConfigBuilder("spark.buffer.write.chunkSize") + .internal() + .doc("The chunk size during writing out the bytes of ChunkedByteBuffer.") + .bytesConf(ByteUnit.BYTE) + .checkValue(_ <= Int.MaxValue, "The chunk size during writing out the bytes of" + + " ChunkedByteBuffer should not larger than Int.MaxValue.") + .createWithDefault(64 * 1024 * 1024) + private[spark] val CHECKPOINT_COMPRESS = ConfigBuilder("spark.checkpoint.compress") .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " + "spark.io.compression.codec.") .booleanConf .createWithDefault(false) + + private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD = + ConfigBuilder("spark.shuffle.accurateBlockThreshold") + .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " + + "record the size accurately if it's above this config. This helps to prevent OOM by " + + "avoiding underestimating shuffle block size when fetch shuffle blocks.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(100 * 1024 * 1024) + + private[spark] val SHUFFLE_REGISTRATION_TIMEOUT = + ConfigBuilder("spark.shuffle.registration.timeout") + .doc("Timeout in milliseconds for registration to the external shuffle service.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(5000) + + private[spark] val SHUFFLE_REGISTRATION_MAX_ATTEMPTS = + ConfigBuilder("spark.shuffle.registration.maxAttempts") + .doc("When we fail to register to the external shuffle service, we will " + + "retry for maxAttempts times.") + .intConf + .createWithDefault(3) + + private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS = + ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress") + .doc("This configuration limits the number of remote blocks being fetched per reduce task" + + " from a given host port. When a large number of blocks are being requested from a given" + + " address in a single fetch or simultaneously, this could crash the serving executor or" + + " Node Manager. This is especially useful to reduce the load on the Node Manager when" + + " external shuffle is enabled. 
You can mitigate the issue by setting it to a lower value.") + .intConf + .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") + .createWithDefault(Int.MaxValue) + + private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = + ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") + .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + + "above this threshold. This is to avoid a giant request takes too much memory. We can " + + "enable this config by setting a specific value(e.g. 200m). Note that this config can " + + "be enabled only when the shuffle shuffle service is newer than Spark-2.2 or the shuffle" + + " service is disabled.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Long.MaxValue) + + private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = + ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") + .doc("Enable tracking of updatedBlockStatuses in the TaskMetrics. Off by default since " + + "tracking the block statuses can use a lot of memory and its not used anywhere within " + + "spark.") + .booleanConf + .createWithDefault(false) + + private[spark] val SHUFFLE_FILE_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.file.buffer") + .doc("Size of the in-memory buffer for each shuffle file output stream. " + + "These buffers reduce the number of disk seeks and system calls made " + + "in creating intermediate shuffle files.") + .bytesConf(ByteUnit.KiB) + .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, + s"The file buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") + .createWithDefaultString("32k") + + private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.unsafe.file.output.buffer") + .doc("The file system for this buffer size after each partition " + + "is written in unsafe shuffle writer.") + .bytesConf(ByteUnit.KiB) + .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, + s"The buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") + .createWithDefaultString("32k") + + private[spark] val SHUFFLE_DISK_WRITE_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.spill.diskWriteBufferSize") + .doc("The buffer size to use when writing the sorted records to an on-disk file.") + .bytesConf(ByteUnit.BYTE) + .checkValue(v => v > 0 && v <= Int.MaxValue, + s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") + .createWithDefault(1024 * 1024) + + private[spark] val UNROLL_MEMORY_CHECK_PERIOD = + ConfigBuilder("spark.storage.unrollMemoryCheckPeriod") + .internal() + .doc("The memory check period is used to determine how often we should check whether " + + "there is a need to request more memory when we try to unroll the given block in memory.") + .longConf + .createWithDefault(16) + + private[spark] val UNROLL_MEMORY_GROWTH_FACTOR = + ConfigBuilder("spark.storage.unrollMemoryGrowthFactor") + .internal() + .doc("Memory to request as a multiple of the size that used to unroll the block.") + .doubleConf + .createWithDefault(1.5) + + private[spark] val FORCE_DOWNLOAD_SCHEMES = + ConfigBuilder("spark.yarn.dist.forceDownloadSchemes") + .doc("Comma-separated list of schemes for which files will be downloaded to the " + + "local disk prior to being added to YARN's distributed cache. 
For use in cases " + + "where the YARN service does not support schemes that are supported by Spark, like http, " + + "https and ftp.") + .stringConf + .toSequence + .createWithDefault(Nil) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 7efa9416362a..50f51e1af453 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -24,12 +24,12 @@ import org.apache.spark.util.Utils /** - * An interface to define how a single Spark job commits its outputs. Two notes: + * An interface to define how a single Spark job commits its outputs. Three notes: * * 1. Implementations must be serializable, as the committer instance instantiated on the driver * will be used for tasks on executors. - * 2. Implementations should have a constructor with either 2 or 3 arguments: - * (jobId: String, path: String) or (jobId: String, path: String, isAppend: Boolean). + * 2. Implementations should have a constructor with 2 arguments: + * (jobId: String, path: String) * 3. A committer should not be reused across multiple Spark jobs. * * The proper call sequence is: @@ -139,19 +139,10 @@ object FileCommitProtocol { /** * Instantiates a FileCommitProtocol using the given className. */ - def instantiate(className: String, jobId: String, outputPath: String, isAppend: Boolean) + def instantiate(className: String, jobId: String, outputPath: String) : FileCommitProtocol = { val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] - - // First try the one with argument (jobId: String, outputPath: String, isAppend: Boolean). - // If that doesn't exist, try the one with (jobId: string, outputPath: String). - try { - val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) - ctor.newInstance(jobId, outputPath, isAppend.asInstanceOf[java.lang.Boolean]) - } catch { - case _: NoSuchMethodException => - val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) - ctor.newInstance(jobId, outputPath) - } + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala new file mode 100644 index 000000000000..ddbd624b380d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{TaskAttemptContext => NewTaskAttemptContext} + +/** + * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. + */ +class HadoopMapRedCommitProtocol(jobId: String, path: String) + extends HadoopMapReduceCommitProtocol(jobId, path) { + + override def setupCommitter(context: NewTaskAttemptContext): OutputCommitter = { + val config = context.getConfiguration.asInstanceOf[JobConf] + config.getOutputCommitter + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 22e26799138b..b1d07ab2c919 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -73,7 +73,8 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) val stagingDir: String = committer match { // For FileOutputCommitter it has its own staging path called "work path". - case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path) + case f: FileOutputCommitter => + Option(f.getWorkPath).map(_.toString).getOrElse(path) case _ => path } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala new file mode 100644 index 000000000000..9b987e0e1bb6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import scala.reflect.ClassTag + +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.SparkConf + +/** + * Interface for create output format/committer/writer used during saving an RDD using a Hadoop + * OutputFormat (both from the old mapred API and the new mapreduce API) + * + * Notes: + * 1. Implementations should throw [[IllegalArgumentException]] when wrong hadoop API is + * referenced; + * 2. Implementations must be serializable, as the instance instantiated on the driver + * will be used for tasks on executors; + * 3. Implementations should have a constructor with exactly one argument: + * (conf: SerializableConfiguration) or (conf: SerializableJobConf). 
+ */ +abstract class HadoopWriteConfigUtil[K, V: ClassTag] extends Serializable { + + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- + + def createJobContext(jobTrackerId: String, jobId: Int): JobContext + + def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): TaskAttemptContext + + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- + + def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol + + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + def initWriter(taskContext: TaskAttemptContext, splitId: Int): Unit + + def write(pair: (K, V)): Unit + + def closeWriter(taskContext: TaskAttemptContext): Unit + + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- + + def initOutputFormat(jobContext: JobContext): Unit + + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- + + def assertConf(jobContext: JobContext, conf: SparkConf): Unit +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala deleted file mode 100644 index 376ff9bb19f7..000000000000 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.internal.io - -import java.text.SimpleDateFormat -import java.util.{Date, Locale} - -import scala.reflect.ClassTag -import scala.util.DynamicVariable - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.{JobConf, JobID} -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl - -import org.apache.spark.{SparkConf, SparkException, TaskContext} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.OutputMetrics -import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage -import org.apache.spark.rdd.RDD -import org.apache.spark.util.{SerializableConfiguration, Utils} - -/** - * A helper object that saves an RDD using a Hadoop OutputFormat - * (from the newer mapreduce API, not the old mapred API). - */ -private[spark] -object SparkHadoopMapReduceWriter extends Logging { - - /** - * Basic work flow of this command is: - * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to - * be issued. - * 2. Issues a write job consists of one or more executor side tasks, each of which writes all - * rows within an RDD partition. - * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any - * exception is thrown during task commitment, also aborts that task. - * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is - * thrown during job commitment, also aborts the job. - */ - def write[K, V: ClassTag]( - rdd: RDD[(K, V)], - hadoopConf: Configuration): Unit = { - // Extract context and configuration from RDD. - val sparkContext = rdd.context - val stageId = rdd.id - val sparkConf = rdd.conf - val conf = new SerializableConfiguration(hadoopConf) - - // Set up a job. - val jobTrackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date()) - val jobAttemptId = new TaskAttemptID(jobTrackerId, stageId, TaskType.MAP, 0, 0) - val jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId) - val format = jobContext.getOutputFormatClass - - if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(sparkConf)) { - // FileOutputFormat ignores the filesystem parameter - val jobFormat = format.newInstance - jobFormat.checkOutputSpecs(jobContext) - } - - val committer = FileCommitProtocol.instantiate( - className = classOf[HadoopMapReduceCommitProtocol].getName, - jobId = stageId.toString, - outputPath = conf.value.get("mapreduce.output.fileoutputformat.outputdir"), - isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] - committer.setupJob(jobContext) - - // Try to write all RDD partitions as a Hadoop OutputFormat. 
- try { - val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { - executeTask( - context = context, - jobTrackerId = jobTrackerId, - sparkStageId = context.stageId, - sparkPartitionId = context.partitionId, - sparkAttemptNumber = context.attemptNumber, - committer = committer, - hadoopConf = conf.value, - outputFormat = format.asInstanceOf[Class[OutputFormat[K, V]]], - iterator = iter) - }) - - committer.commitJob(jobContext, ret) - logInfo(s"Job ${jobContext.getJobID} committed.") - } catch { - case cause: Throwable => - logError(s"Aborting job ${jobContext.getJobID}.", cause) - committer.abortJob(jobContext) - throw new SparkException("Job aborted.", cause) - } - } - - /** Write an RDD partition out in a single Spark task. */ - private def executeTask[K, V: ClassTag]( - context: TaskContext, - jobTrackerId: String, - sparkStageId: Int, - sparkPartitionId: Int, - sparkAttemptNumber: Int, - committer: FileCommitProtocol, - hadoopConf: Configuration, - outputFormat: Class[_ <: OutputFormat[K, V]], - iterator: Iterator[(K, V)]): TaskCommitMessage = { - // Set up a task. - val attemptId = new TaskAttemptID(jobTrackerId, sparkStageId, TaskType.REDUCE, - sparkPartitionId, sparkAttemptNumber) - val taskContext = new TaskAttemptContextImpl(hadoopConf, attemptId) - committer.setupTask(taskContext) - - val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context) - - // Initiate the writer. - val taskFormat = outputFormat.newInstance() - // If OutputFormat is Configurable, we should set conf to it. - taskFormat match { - case c: Configurable => c.setConf(hadoopConf) - case _ => () - } - var writer = taskFormat.getRecordWriter(taskContext) - .asInstanceOf[RecordWriter[K, V]] - require(writer != null, "Unable to obtain RecordWriter") - var recordsWritten = 0L - - // Write all rows in RDD partition. - try { - val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { - // Write rows out, release resource and commit the task. - while (iterator.hasNext) { - val pair = iterator.next() - writer.write(pair._1, pair._2) - - // Update bytes written metric every few records - SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) - recordsWritten += 1 - } - if (writer != null) { - writer.close(taskContext) - writer = null - } - committer.commitTask(taskContext) - }(catchBlock = { - // If there is an error, release resource and then abort the task. 
- try { - if (writer != null) { - writer.close(taskContext) - writer = null - } - } finally { - committer.abortTask(taskContext) - logError(s"Task ${taskContext.getTaskAttemptID} aborted.") - } - }) - - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - - ret - } catch { - case t: Throwable => - throw new SparkException("Task failed while writing rows", t) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index acc9c3857100..949d8c677998 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -17,143 +17,374 @@ package org.apache.spark.internal.io -import java.io.IOException -import java.text.{NumberFormat, SimpleDateFormat} +import java.text.NumberFormat import java.util.{Date, Locale} +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ -import org.apache.hadoop.mapreduce.TaskType +import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, +OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, +TaskAttemptContext => NewTaskAttemptContext, TaskAttemptID => NewTaskAttemptID, TaskType} +import org.apache.hadoop.mapreduce.task.{TaskAttemptContextImpl => NewTaskAttemptContextImpl} -import org.apache.spark.SerializableWritable +import org.apache.spark.{SerializableWritable, SparkConf, SparkException, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.rdd.HadoopRDD -import org.apache.spark.util.SerializableJobConf +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils} /** - * Internal helper class that saves an RDD using a Hadoop OutputFormat. - * - * Saves the RDD using a JobConf, which should contain an output key class, an output value class, - * a filename to write to, etc, exactly like in a Hadoop MapReduce job. + * A helper object that saves an RDD using a Hadoop OutputFormat. + */ +private[spark] +object SparkHadoopWriter extends Logging { + import SparkHadoopWriterUtils._ + + /** + * Basic work flow of this command is: + * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to + * be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ + def write[K, V: ClassTag]( + rdd: RDD[(K, V)], + config: HadoopWriteConfigUtil[K, V]): Unit = { + // Extract context and configuration from RDD. + val sparkContext = rdd.context + val stageId = rdd.id + + // Set up a job. 
+ val jobTrackerId = createJobTrackerID(new Date()) + val jobContext = config.createJobContext(jobTrackerId, stageId) + config.initOutputFormat(jobContext) + + // Assert the output format/key/value class is set in JobConf. + config.assertConf(jobContext, rdd.conf) + + val committer = config.createCommitter(stageId) + committer.setupJob(jobContext) + + // Try to write all RDD partitions as a Hadoop OutputFormat. + try { + val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + executeTask( + context = context, + config = config, + jobTrackerId = jobTrackerId, + sparkStageId = context.stageId, + sparkPartitionId = context.partitionId, + sparkAttemptNumber = context.attemptNumber, + committer = committer, + iterator = iter) + }) + + committer.commitJob(jobContext, ret) + logInfo(s"Job ${jobContext.getJobID} committed.") + } catch { + case cause: Throwable => + logError(s"Aborting job ${jobContext.getJobID}.", cause) + committer.abortJob(jobContext) + throw new SparkException("Job aborted.", cause) + } + } + + /** Write a RDD partition out in a single Spark task. */ + private def executeTask[K, V: ClassTag]( + context: TaskContext, + config: HadoopWriteConfigUtil[K, V], + jobTrackerId: String, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + committer: FileCommitProtocol, + iterator: Iterator[(K, V)]): TaskCommitMessage = { + // Set up a task. + val taskContext = config.createTaskAttemptContext( + jobTrackerId, sparkStageId, sparkPartitionId, sparkAttemptNumber) + committer.setupTask(taskContext) + + val (outputMetrics, callback) = initHadoopOutputMetrics(context) + + // Initiate the writer. + config.initWriter(taskContext, sparkPartitionId) + var recordsWritten = 0L + + // Write all rows in RDD partition. + try { + val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { + while (iterator.hasNext) { + val pair = iterator.next() + config.write(pair) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) + recordsWritten += 1 + } + + config.closeWriter(taskContext) + committer.commitTask(taskContext) + }(catchBlock = { + // If there is an error, release resource and then abort the task. + try { + config.closeWriter(taskContext) + } finally { + committer.abortTask(taskContext) + logError(s"Task ${taskContext.getTaskAttemptID} aborted.") + } + }) + + outputMetrics.setBytesWritten(callback()) + outputMetrics.setRecordsWritten(recordsWritten) + + ret + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) + } + } +} + +/** + * A helper class that reads JobConf from older mapred API, creates output Format/Committer/Writer. 
*/ private[spark] -class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { +class HadoopMapRedWriteConfigUtil[K, V: ClassTag](conf: SerializableJobConf) + extends HadoopWriteConfigUtil[K, V] with Logging { - private val now = new Date() - private val conf = new SerializableJobConf(jobConf) + private var outputFormat: Class[_ <: OutputFormat[K, V]] = null + private var writer: RecordWriter[K, V] = null - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null + private def getConf: JobConf = conf.value - @transient private var writer: RecordWriter[AnyRef, AnyRef] = null - @transient private var format: OutputFormat[AnyRef, AnyRef] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- - def preSetup() { - setIDs(0, 0, 0) - HadoopRDD.addLocalConfiguration("", 0, 0, 0, conf.value) + override def createJobContext(jobTrackerId: String, jobId: Int): NewJobContext = { + val jobAttemptId = new SerializableWritable(new JobID(jobTrackerId, jobId)) + new JobContextImpl(getConf, jobAttemptId.value) + } - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) + override def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): NewTaskAttemptContext = { + // Update JobConf. + HadoopRDD.addLocalConfiguration(jobTrackerId, jobId, splitId, taskAttemptId, conf.value) + // Create taskContext. + val attemptId = new TaskAttemptID(jobTrackerId, jobId, TaskType.MAP, splitId, taskAttemptId) + new TaskAttemptContextImpl(getConf, attemptId) } + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), - jobid, splitID, attemptID, conf.value) + override def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol = { + // Update JobConf. + HadoopRDD.addLocalConfiguration("", 0, 0, 0, getConf) + // Create commit protocol. 
+ FileCommitProtocol.instantiate( + className = classOf[HadoopMapRedCommitProtocol].getName, + jobId = jobId.toString, + outputPath = getConf.get("mapred.output.dir") + ).asInstanceOf[HadoopMapReduceCommitProtocol] } - def open() { + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = { val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) - val outputName = "part-" + numfmt.format(splitID) - val path = FileOutputFormat.getOutputPath(conf.value) + val outputName = "part-" + numfmt.format(splitId) + val path = FileOutputFormat.getOutputPath(getConf) val fs: FileSystem = { if (path != null) { - path.getFileSystem(conf.value) + path.getFileSystem(getConf) } else { - FileSystem.get(conf.value) + FileSystem.get(getConf) } } - getOutputCommitter().setupTask(getTaskContext()) - writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL) + writer = getConf.getOutputFormat + .getRecordWriter(fs, getConf, outputName, Reporter.NULL) + .asInstanceOf[RecordWriter[K, V]] + + require(writer != null, "Unable to obtain RecordWriter") } - def write(key: AnyRef, value: AnyRef) { + override def write(pair: (K, V)): Unit = { + require(writer != null, "Must call createWriter before write.") + writer.write(pair._1, pair._2) + } + + override def closeWriter(taskContext: NewTaskAttemptContext): Unit = { if (writer != null) { - writer.write(key, value) - } else { - throw new IOException("Writer is null, open() has not been called") + writer.close(Reporter.NULL) } } - def close() { - writer.close(Reporter.NULL) - } + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- - def commit() { - SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) + override def initOutputFormat(jobContext: NewJobContext): Unit = { + if (outputFormat == null) { + outputFormat = getConf.getOutputFormat.getClass + .asInstanceOf[Class[_ <: OutputFormat[K, V]]] + } } - def commitJob() { - val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) + private def getOutputFormat(): OutputFormat[K, V] = { + require(outputFormat != null, "Must call initOutputFormat first.") + + outputFormat.newInstance() } - // ********* Private Functions ********* + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- + + override def assertConf(jobContext: NewJobContext, conf: SparkConf): Unit = { + val outputFormatInstance = getOutputFormat() + val keyClass = getConf.getOutputKeyClass + val valueClass = getConf.getOutputValueClass + if (outputFormatInstance == null) { + throw new SparkException("Output format class not set") + } + if (keyClass == null) { + throw new SparkException("Output key class not set") + } + if (valueClass == null) { + throw new SparkException("Output value class not set") + } + SparkHadoopUtil.get.addCredentials(getConf) + + logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + + valueClass.getSimpleName + ")") - private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = { - if (format == null) { - format = 
conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef, AnyRef]] + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(conf)) { + // FileOutputFormat ignores the filesystem parameter + val ignoredFs = FileSystem.get(getConf) + getOutputFormat().checkOutputSpecs(ignoredFs, getConf) } - format + } +} + +/** + * A helper class that reads Configuration from newer mapreduce API, creates output + * Format/Committer/Writer. + */ +private[spark] +class HadoopMapReduceWriteConfigUtil[K, V: ClassTag](conf: SerializableConfiguration) + extends HadoopWriteConfigUtil[K, V] with Logging { + + private var outputFormat: Class[_ <: NewOutputFormat[K, V]] = null + private var writer: NewRecordWriter[K, V] = null + + private def getConf: Configuration = conf.value + + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- + + override def createJobContext(jobTrackerId: String, jobId: Int): NewJobContext = { + val jobAttemptId = new NewTaskAttemptID(jobTrackerId, jobId, TaskType.MAP, 0, 0) + new NewTaskAttemptContextImpl(getConf, jobAttemptId) + } + + override def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): NewTaskAttemptContext = { + val attemptId = new NewTaskAttemptID( + jobTrackerId, jobId, TaskType.REDUCE, splitId, taskAttemptId) + new NewTaskAttemptContextImpl(getConf, attemptId) + } + + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- + + override def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol = { + FileCommitProtocol.instantiate( + className = classOf[HadoopMapReduceCommitProtocol].getName, + jobId = jobId.toString, + outputPath = getConf.get("mapreduce.output.fileoutputformat.outputdir") + ).asInstanceOf[HadoopMapReduceCommitProtocol] } - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = { + val taskFormat = getOutputFormat() + // If OutputFormat is Configurable, we should set conf to it. 
+ taskFormat match { + case c: Configurable => c.setConf(getConf) + case _ => () } - committer + + writer = taskFormat.getRecordWriter(taskContext) + .asInstanceOf[NewRecordWriter[K, V]] + + require(writer != null, "Unable to obtain RecordWriter") + } + + override def write(pair: (K, V)): Unit = { + require(writer != null, "Must call createWriter before write.") + writer.write(pair._1, pair._2) } - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = new JobContextImpl(conf.value, jID.value) + override def closeWriter(taskContext: NewTaskAttemptContext): Unit = { + if (writer != null) { + writer.close(taskContext) + writer = null + } else { + logWarning("Writer has been closed.") } - jobContext } - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- + + override def initOutputFormat(jobContext: NewJobContext): Unit = { + if (outputFormat == null) { + outputFormat = jobContext.getOutputFormatClass + .asInstanceOf[Class[_ <: NewOutputFormat[K, V]]] } - taskContext } - protected def newTaskAttemptContext( - conf: JobConf, - attemptId: TaskAttemptID): TaskAttemptContext = { - new TaskAttemptContextImpl(conf, attemptId) + private def getOutputFormat(): NewOutputFormat[K, V] = { + require(outputFormat != null, "Must call initOutputFormat first.") + + outputFormat.newInstance() } - private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { - jobID = jobid - splitID = splitid - attemptID = attemptid + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- - jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobid)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) + override def assertConf(jobContext: NewJobContext, conf: SparkConf): Unit = { + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(conf)) { + getOutputFormat().checkOutputSpecs(jobContext) + } } } diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0cb16f0627b7..27f2e429395d 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -21,7 +21,7 @@ import java.io._ import java.util.Locale import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} -import net.jpountz.lz4.LZ4BlockOutputStream +import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream} import org.apache.spark.SparkConf @@ -115,7 +115,10 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { new LZ4BlockOutputStream(s, blockSize) } - override def compressedInputStream(s: InputStream): InputStream = new LZ4BlockInputStream(s) + override def compressedInputStream(s: InputStream): InputStream = { + val disableConcatenationOfByteStream = false + new LZ4BlockInputStream(s, disableConcatenationOfByteStream) + } } diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala 
b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index fea2808218a5..78edd2c4d7fa 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -143,7 +143,7 @@ private[spark] class UnifiedMemoryManager private[memory] ( } executionPool.acquireMemory( - numBytes, taskAttemptId, maybeGrowExecutionPool, computeMaxExecutionPoolSize) + numBytes, taskAttemptId, maybeGrowExecutionPool, () => computeMaxExecutionPoolSize) } override def acquireStorageMemory( @@ -160,7 +160,7 @@ private[spark] class UnifiedMemoryManager private[memory] ( case MemoryMode.OFF_HEAP => ( offHeapExecutionMemoryPool, offHeapStorageMemoryPool, - maxOffHeapMemory) + maxOffHeapStorageMemory) } if (numBytes > maxMemory) { // Fail fast if the block simply won't fit @@ -171,7 +171,8 @@ private[spark] class UnifiedMemoryManager private[memory] ( if (numBytes > storagePool.memoryFree) { // There is not enough free memory in the storage pool, so try to borrow free memory from // the execution pool. - val memoryBorrowedFromExecution = Math.min(executionPool.memoryFree, numBytes) + val memoryBorrowedFromExecution = Math.min(executionPool.memoryFree, + numBytes - storagePool.memoryFree) executionPool.decrementPoolSize(memoryBorrowedFromExecution) storagePool.incrementPoolSize(memoryBorrowedFromExecution) } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 1d494500cdb5..3457a2632277 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -26,8 +26,8 @@ import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.metrics.sink.{MetricsServlet, Sink} import org.apache.spark.metrics.source.{Source, StaticSources} import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 23e31823f493..ac33e68abb49 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -68,8 +68,8 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase(Locale.ROOT)) match { - case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) - case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) + case Some("udp") => new GraphiteUDP(host, port) + case Some("tcp") | None => new Graphite(host, port) case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/StatsdReporter.scala b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdReporter.scala new file mode 100644 index 000000000000..ba75aa1c65cc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdReporter.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import java.io.IOException +import java.net.{DatagramPacket, DatagramSocket, InetSocketAddress} +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.SortedMap +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + +import com.codahale.metrics._ +import org.apache.hadoop.net.NetUtils + +import org.apache.spark.internal.Logging + +/** + * @see + * StatsD metric types + */ +private[spark] object StatsdMetricType { + val COUNTER = "c" + val GAUGE = "g" + val TIMER = "ms" + val Set = "s" +} + +private[spark] class StatsdReporter( + registry: MetricRegistry, + host: String = "127.0.0.1", + port: Int = 8125, + prefix: String = "", + filter: MetricFilter = MetricFilter.ALL, + rateUnit: TimeUnit = TimeUnit.SECONDS, + durationUnit: TimeUnit = TimeUnit.MILLISECONDS) + extends ScheduledReporter(registry, "statsd-reporter", filter, rateUnit, durationUnit) + with Logging { + + import StatsdMetricType._ + + private val address = new InetSocketAddress(host, port) + private val whitespace = "[\\s]+".r + + override def report( + gauges: SortedMap[String, Gauge[_]], + counters: SortedMap[String, Counter], + histograms: SortedMap[String, Histogram], + meters: SortedMap[String, Meter], + timers: SortedMap[String, Timer]): Unit = + Try(new DatagramSocket) match { + case Failure(ioe: IOException) => logWarning("StatsD datagram socket construction failed", + NetUtils.wrapException(host, port, NetUtils.getHostname(), 0, ioe)) + case Failure(e) => logWarning("StatsD datagram socket construction failed", e) + case Success(s) => + implicit val socket = s + val localAddress = Try(socket.getLocalAddress).map(_.getHostAddress).getOrElse(null) + val localPort = socket.getLocalPort + Try { + gauges.entrySet.asScala.foreach(e => reportGauge(e.getKey, e.getValue)) + counters.entrySet.asScala.foreach(e => reportCounter(e.getKey, e.getValue)) + histograms.entrySet.asScala.foreach(e => reportHistogram(e.getKey, e.getValue)) + meters.entrySet.asScala.foreach(e => reportMetered(e.getKey, e.getValue)) + timers.entrySet.asScala.foreach(e => reportTimer(e.getKey, e.getValue)) + } recover { + case ioe: IOException => + logDebug(s"Unable to send packets to StatsD", NetUtils.wrapException( + address.getHostString, address.getPort, localAddress, localPort, ioe)) + case e: Throwable => logDebug(s"Unable to send packets to StatsD at '$host:$port'", e) + } + Try(socket.close()) recover { + case ioe: IOException => + logDebug("Error when close socket to StatsD", NetUtils.wrapException( + address.getHostString, address.getPort, localAddress, localPort, ioe)) + case e: Throwable => logDebug("Error when close socket to StatsD", e) + } + } + + private def reportGauge(name: String, gauge: Gauge[_])(implicit 
socket: DatagramSocket): Unit = + formatAny(gauge.getValue).foreach(v => send(fullName(name), v, GAUGE)) + + private def reportCounter(name: String, counter: Counter)(implicit socket: DatagramSocket): Unit = + send(fullName(name), format(counter.getCount), COUNTER) + + private def reportHistogram(name: String, histogram: Histogram) + (implicit socket: DatagramSocket): Unit = { + val snapshot = histogram.getSnapshot + send(fullName(name, "count"), format(histogram.getCount), GAUGE) + send(fullName(name, "max"), format(snapshot.getMax), TIMER) + send(fullName(name, "mean"), format(snapshot.getMean), TIMER) + send(fullName(name, "min"), format(snapshot.getMin), TIMER) + send(fullName(name, "stddev"), format(snapshot.getStdDev), TIMER) + send(fullName(name, "p50"), format(snapshot.getMedian), TIMER) + send(fullName(name, "p75"), format(snapshot.get75thPercentile), TIMER) + send(fullName(name, "p95"), format(snapshot.get95thPercentile), TIMER) + send(fullName(name, "p98"), format(snapshot.get98thPercentile), TIMER) + send(fullName(name, "p99"), format(snapshot.get99thPercentile), TIMER) + send(fullName(name, "p999"), format(snapshot.get999thPercentile), TIMER) + } + + private def reportMetered(name: String, meter: Metered)(implicit socket: DatagramSocket): Unit = { + send(fullName(name, "count"), format(meter.getCount), GAUGE) + send(fullName(name, "m1_rate"), format(convertRate(meter.getOneMinuteRate)), TIMER) + send(fullName(name, "m5_rate"), format(convertRate(meter.getFiveMinuteRate)), TIMER) + send(fullName(name, "m15_rate"), format(convertRate(meter.getFifteenMinuteRate)), TIMER) + send(fullName(name, "mean_rate"), format(convertRate(meter.getMeanRate)), TIMER) + } + + private def reportTimer(name: String, timer: Timer)(implicit socket: DatagramSocket): Unit = { + val snapshot = timer.getSnapshot + send(fullName(name, "max"), format(convertDuration(snapshot.getMax)), TIMER) + send(fullName(name, "mean"), format(convertDuration(snapshot.getMean)), TIMER) + send(fullName(name, "min"), format(convertDuration(snapshot.getMin)), TIMER) + send(fullName(name, "stddev"), format(convertDuration(snapshot.getStdDev)), TIMER) + send(fullName(name, "p50"), format(convertDuration(snapshot.getMedian)), TIMER) + send(fullName(name, "p75"), format(convertDuration(snapshot.get75thPercentile)), TIMER) + send(fullName(name, "p95"), format(convertDuration(snapshot.get95thPercentile)), TIMER) + send(fullName(name, "p98"), format(convertDuration(snapshot.get98thPercentile)), TIMER) + send(fullName(name, "p99"), format(convertDuration(snapshot.get99thPercentile)), TIMER) + send(fullName(name, "p999"), format(convertDuration(snapshot.get999thPercentile)), TIMER) + + reportMetered(name, timer) + } + + private def send(name: String, value: String, metricType: String) + (implicit socket: DatagramSocket): Unit = { + val bytes = sanitize(s"$name:$value|$metricType").getBytes(UTF_8) + val packet = new DatagramPacket(bytes, bytes.length, address) + socket.send(packet) + } + + private def fullName(names: String*): String = MetricRegistry.name(prefix, names : _*) + + private def sanitize(s: String): String = whitespace.replaceAllIn(s, "-") + + private def format(v: Any): String = formatAny(v).getOrElse("") + + private def formatAny(v: Any): Option[String] = + v match { + case f: Float => Some("%2.2f".format(f)) + case d: Double => Some("%2.2f".format(d)) + case b: BigDecimal => Some("%2.2f".format(b)) + case n: Number => Some(v.toString) + case _ => None + } +} + diff --git 
a/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala new file mode 100644 index 000000000000..859a2f6bcd45 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import java.util.Properties +import java.util.concurrent.TimeUnit + +import com.codahale.metrics.MetricRegistry + +import org.apache.spark.SecurityManager +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.MetricsSystem + +private[spark] object StatsdSink { + val STATSD_KEY_HOST = "host" + val STATSD_KEY_PORT = "port" + val STATSD_KEY_PERIOD = "period" + val STATSD_KEY_UNIT = "unit" + val STATSD_KEY_PREFIX = "prefix" + + val STATSD_DEFAULT_HOST = "127.0.0.1" + val STATSD_DEFAULT_PORT = "8125" + val STATSD_DEFAULT_PERIOD = "10" + val STATSD_DEFAULT_UNIT = "SECONDS" + val STATSD_DEFAULT_PREFIX = "" +} + +private[spark] class StatsdSink( + val property: Properties, + val registry: MetricRegistry, + securityMgr: SecurityManager) + extends Sink with Logging { + import StatsdSink._ + + val host = property.getProperty(STATSD_KEY_HOST, STATSD_DEFAULT_HOST) + val port = property.getProperty(STATSD_KEY_PORT, STATSD_DEFAULT_PORT).toInt + + val pollPeriod = property.getProperty(STATSD_KEY_PERIOD, STATSD_DEFAULT_PERIOD).toInt + val pollUnit = + TimeUnit.valueOf(property.getProperty(STATSD_KEY_UNIT, STATSD_DEFAULT_UNIT).toUpperCase) + + val prefix = property.getProperty(STATSD_KEY_PREFIX, STATSD_DEFAULT_PREFIX) + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val reporter = new StatsdReporter(registry, host, port, prefix) + + override def start(): Unit = { + reporter.start(pollPeriod, pollUnit) + logInfo(s"StatsdSink started with prefix: '$prefix'") + } + + override def stop(): Unit = { + reporter.stop() + logInfo("StatsdSink stopped.") + } + + override def report(): Unit = reporter.report() +} + diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 8f83668d7902..b3f8bfe8b1d4 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -46,5 +46,5 @@ trait BlockDataManager { /** * Release locks acquired by [[putBlockData()]] and [[getBlockData()]]. 
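Enabling the new StatsD sink is purely a configuration change. A hedged sketch, assuming the standard metrics registration mechanism (`metrics.properties`, or equivalently `spark.metrics.conf.`-prefixed SparkConf entries) and illustrative host, port and prefix values:

```scala
// Configuration sketch; the keys mirror StatsdSink.STATSD_KEY_* in this patch, and the
// "spark.metrics.conf." prefix is the usual way to inline metrics.properties entries
// into SparkConf. Host, port and prefix values below are assumptions.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("statsd-metrics-sketch")
  .set("spark.metrics.conf.*.sink.statsd.class", "org.apache.spark.metrics.sink.StatsdSink")
  .set("spark.metrics.conf.*.sink.statsd.host", "127.0.0.1") // STATSD_DEFAULT_HOST
  .set("spark.metrics.conf.*.sink.statsd.port", "8125")      // STATSD_DEFAULT_PORT
  .set("spark.metrics.conf.*.sink.statsd.period", "10")      // report every 10 units
  .set("spark.metrics.conf.*.sink.statsd.unit", "seconds")   // parsed with TimeUnit.valueOf after uppercasing
  .set("spark.metrics.conf.*.sink.statsd.prefix", "myapp")   // prepended via MetricRegistry.name
```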
*/ - def releaseLock(blockId: BlockId): Unit + def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index cb9d389dd7ea..fe5fd2da039b 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.ThreadUtils @@ -67,7 +67,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockIds: Array[String], - listener: BlockFetchingListener): Unit + listener: BlockFetchingListener, + tempShuffleFileManager: TempShuffleFileManager): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -100,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo ret.flip() result.success(new NioManagedBuffer(ret)) } - }) + }, tempShuffleFileManager = null) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 305fd9a6de10..eb4cf94164fd 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import org.apache.spark.internal.Logging import org.apache.spark.network.BlockDataManager -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.buffer.NioManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b75e91b66096..ac4d85004bad 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -18,18 +18,21 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer +import java.util.{HashMap => JHashMap, Map => JMap} import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag +import com.codahale.metrics.{Metric, MetricSet} + import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import 
org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -83,18 +86,33 @@ private[spark] class NettyBlockTransferService( Utils.startServiceOnPort(_port, startService, conf, getClass.getName)._1 } + override def shuffleMetrics(): MetricSet = { + require(server != null && clientFactory != null, "NettyBlockTransferServer is not initialized") + + new MetricSet { + val allMetrics = new JHashMap[String, Metric]() + override def getMetrics: JMap[String, Metric] = { + allMetrics.putAll(clientFactory.getAllMetrics.getMetrics) + allMetrics.putAll(server.getAllMetrics.getMetrics) + allMetrics + } + } + } + override def fetchBlocks( host: String, port: Int, execId: String, blockIds: Array[String], - listener: BlockFetchingListener): Unit = { + listener: BlockFetchingListener, + tempShuffleFileManager: TempShuffleFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start() + new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, + transportConf, tempShuffleFileManager).start() } } diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 2610d6f6e45a..8058a4d5dbde 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -17,6 +17,8 @@ package org.apache +import java.util.Properties + /** * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to * Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection, @@ -40,9 +42,6 @@ package org.apache * Developer API are intended for advanced users want to extend Spark through lower * level interfaces. These are subject to changes or removal in minor releases. */ - -import java.util.Properties - package object spark { private object SparkBuildInfo { @@ -57,6 +56,9 @@ package object spark { val resourceStream = Thread.currentThread().getContextClassLoader. 
getResourceAsStream("spark-version-info.properties") + if (resourceStream == null) { + throw new SparkException("Could not find spark-version-info.properties") + } try { val unknownProp = "" @@ -71,8 +73,6 @@ package object spark { props.getProperty("date", unknownProp) ) } catch { - case npe: NullPointerException => - throw new SparkException("Error while locating file spark-version-info.properties", npe) case e: Exception => throw new SparkException("Error loading properties from spark-version-info.properties", e) } finally { diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala index 5a5bd7fbbe2f..cbee13687101 100644 --- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.{PascalDistribution, PoissonDistribution} +import org.apache.commons.math3.distribution.PoissonDistribution /** * An ApproximateEvaluator for counts. @@ -48,22 +48,11 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) private[partial] object CountEvaluator { def bound(confidence: Double, sum: Long, p: Double): BoundedDouble = { - // Let the total count be N. A fraction p has been counted already, with sum 'sum', - // as if each element from the total data set had been seen with probability p. - val dist = - if (sum <= 10000) { - // The remaining count, k=N-sum, may be modeled as negative binomial (aka Pascal), - // where there have been 'sum' successes of probability p already. (There are several - // conventions, but this is the one followed by Commons Math3.) - new PascalDistribution(sum.toInt, p) - } else { - // For large 'sum' (certainly, > Int.MaxValue!), use a Poisson approximation, which has - // a different interpretation. "sum" elements have been observed having scanned a fraction - // p of the data. This suggests data is counted at a rate of sum / p across the whole data - // set. The total expected count from the rest is distributed as - // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p) - new PoissonDistribution(sum * (1 - p) / p) - } + // "sum" elements have been observed having scanned a fraction + // p of the data. This suggests data is counted at a rate of sum / p across the whole data + // set. 
The total expected count from the rest is distributed as + // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p) + val dist = new PoissonDistribution(sum * (1 - p) / p) // Not quite symmetric; calculate interval straight from discrete distribution val low = dist.inverseCumulativeProbability((1 - confidence) / 2) val high = dist.inverseCumulativeProbability((1 + confidence) / 2) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 50d977a92da5..a14bad47dfe1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.spark.{Partition, SparkContext} @@ -35,8 +36,12 @@ private[spark] class BinaryFileRDD[T]( extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance val conf = getConf + // setMinPartitions below will call FileInputFormat.listStatus(), which can be quite slow when + // traversing a large number of directories and files. Parallelize it. + conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, + Runtime.getRuntime.availableProcessors().toString) + val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index a091f06b4ed7..4574c3724962 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -26,8 +26,8 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer -import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils +import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} /** * The references to rdd and splitIndex are transient because redundant information is stored diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 2cba1febe875..10451a324b0f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -269,7 +269,7 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) tries = 0 // if we don't have enough partition groups, create duplicates while (numCreated < targetLen) { - var (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries) + val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries) tries += 1 val pgroup = new PartitionGroup(Some(nxt_replica)) groupArr += pgroup diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 14331dfd0c98..943abae17a91 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -17,8 +17,8 @@ package 
org.apache.spark.rdd -import org.apache.spark.annotation.Since import org.apache.spark.TaskContext +import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.MeanEvaluator @@ -128,9 +128,9 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } // Compute the minimum and the maximum val (max: Double, min: Double) = self.mapPartitions { items => - Iterator(items.foldRight(Double.NegativeInfinity, - Double.PositiveInfinity)((e: Double, x: (Double, Double)) => - (x._1.max(e), x._2.min(e)))) + Iterator( + items.foldRight((Double.NegativeInfinity, Double.PositiveInfinity) + )((e: Double, x: (Double, Double)) => (x._1.max(e), x._2.min(e)))) }.reduce { (maxmin1, maxmin2) => (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 4bf8ecc38354..76ea8b86c53d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -251,7 +251,13 @@ class HadoopRDD[K, V]( null } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener{ context => closeIfNeeded() } + context.addTaskCompletionListener { context => + // Update the bytes read before closing is to make sure lingering bytesRead statistics in + // this thread get correctly added. + updateBytesRead() + closeIfNeeded() + } + private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey() private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue() diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index ce3a9a2a1e2a..482875e6c1ac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -191,7 +191,13 @@ class NewHadoopRDD[K, V]( } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(context => close()) + context.addTaskCompletionListener { context => + // Update the bytesRead before closing is to make sure lingering bytesRead statistics in + // this thread get correctly added. 
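The completion-listener pattern used in the hunk above is also available to user code; a minimal sketch (the input path and the per-partition counter are assumptions, not part of the patch) of why task completion is the right place to flush lingering per-partition state:

```scala
// Sketch of TaskContext.addTaskCompletionListener; values are illustrative.
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

val sc = new SparkContext(new SparkConf().setAppName("listener-sketch").setMaster("local[*]"))
val lengths = sc.textFile("/tmp/input.txt").mapPartitions { iter =>
  var seen = 0L
  TaskContext.get().addTaskCompletionListener { _ =>
    // Runs once per task, even when the iterator is not fully consumed,
    // so per-partition bookkeeping is always flushed.
    println(s"partition done, $seen records seen")
  }
  iter.map { line => seen += 1; line.length }
}
lengths.count()
sc.stop()
```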
+ updateBytesRead() + close() + } + private var havePair = false private var recordsSinceMetricsUpdate = 0 diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 58762cc0838c..e68c6b1366c7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -27,7 +27,6 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} @@ -36,13 +35,11 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewO import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriter, - SparkHadoopWriterUtils} import org.apache.spark.internal.Logging +import org.apache.spark.internal.io._ import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -1082,9 +1079,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { - SparkHadoopMapReduceWriter.write( + val config = new HadoopMapReduceWriteConfigUtil[K, V](new SerializableConfiguration(conf)) + SparkHadoopWriter.write( rdd = self, - hadoopConf = conf) + config = config) } /** @@ -1094,62 +1092,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * MapReduce job. */ def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { - // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). - val hadoopConf = conf - val outputFormatInstance = hadoopConf.getOutputFormat - val keyClass = hadoopConf.getOutputKeyClass - val valueClass = hadoopConf.getOutputValueClass - if (outputFormatInstance == null) { - throw new SparkException("Output format class not set") - } - if (keyClass == null) { - throw new SparkException("Output key class not set") - } - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - SparkHadoopUtil.get.addCredentials(hadoopConf) - - logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + - valueClass.getSimpleName + ")") - - if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) { - // FileOutputFormat ignores the filesystem parameter - val ignoredFs = FileSystem.get(hadoopConf) - hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) - } - - val writer = new SparkHadoopWriter(hadoopConf) - writer.preSetup() - - val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. 
- val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - - val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context) - - writer.setup(context.stageId, context.partitionId, taskAttemptId) - writer.open() - var recordsWritten = 0L - - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (iter.hasNext) { - val record = iter.next() - writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) - - // Update bytes written metric every few records - SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) - recordsWritten += 1 - } - }(finallyBlock = writer.close()) - writer.commit() - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - } - - self.context.runJob(self, writeToFile) - writer.commitJob() + val config = new HadoopMapRedWriteConfigUtil[K, V](new SerializableJobConf(conf)) + SparkHadoopWriter.write( + rdd = self, + config = config) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index 6a89ea878646..15691a8fc8ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -22,8 +22,8 @@ import java.util.Random import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} -import org.apache.spark.util.random.RandomSampler import org.apache.spark.util.Utils +import org.apache.spark.util.random.RandomSampler private[spark] class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 63a87e7f09d8..8798dfc92536 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.io.Codec import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} +import scala.util.hashing import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.io.{BytesWritable, NullWritable, Text} @@ -55,7 +56,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * Doubles; and * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that * can be saved as SequenceFiles. - * All operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] + * All operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)]) * through implicit. * * Internally, each RDD is characterized by five main properties: @@ -448,7 +449,7 @@ abstract class RDD[T: ClassTag]( if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. */ val distributePartition = (index: Int, items: Iterator[T]) => { - var position = (new Random(index)).nextInt(numPartitions) + var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions) items.map { t => // Note that the hash code of the key will just be the key itself. The HashPartitioner // will mod it with the number of total partitions. @@ -1118,9 +1119,9 @@ abstract class RDD[T: ClassTag]( /** * Aggregates the elements of this RDD in a multi-level tree pattern. + * This method is semantically identical to [[org.apache.spark.rdd.RDD#aggregate]]. 
* * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#aggregate]] */ def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, T) => U, @@ -1134,7 +1135,7 @@ abstract class RDD[T: ClassTag]( val cleanCombOp = context.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) + var partiallyAggregated: RDD[U] = mapPartitions(it => Iterator(aggregatePartition(it))) var numPartitions = partiallyAggregated.partitions.length val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce @@ -1146,9 +1147,10 @@ abstract class RDD[T: ClassTag]( val curNumPartitions = numPartitions partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => iter.map((i % curNumPartitions, _)) - }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + }.foldByKey(zeroValue, new HashPartitioner(curNumPartitions))(cleanCombOp).values } - partiallyAggregated.reduce(cleanCombOp) + val copiedZeroValue = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) + partiallyAggregated.fold(copiedZeroValue)(cleanCombOp) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index 37c67cee55f9..979152b55f95 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -152,8 +152,10 @@ private[spark] object ReliableCheckpointRDD extends Logging { sc, checkpointDirPath.toString, originalRDD.partitioner) if (newRDD.partitions.length != originalRDD.partitions.length) { throw new SparkException( - s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + - s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})") + "Checkpoint RDD has a different number of partitions from original RDD. 
Original " + + s"RDD [ID: ${originalRDD.id}, num of partitions: ${originalRDD.partitions.length}]; " + + s"Checkpoint RDD [ID: ${newRDD.id}, num of partitions: " + + s"${newRDD.partitions.length}].") } newRDD } diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 86a332790fb0..02def89dd8c2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.rdd -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import org.apache.hadoop.io.Writable import org.apache.hadoop.io.compress.CompressionCodec @@ -39,40 +39,8 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag extends Logging with Serializable { - private val keyWritableClass = - if (_keyWritableClass == null) { - // pre 1.3.0, we need to use Reflection to get the Writable class - getWritableClass[K]() - } else { - _keyWritableClass - } - - private val valueWritableClass = - if (_valueWritableClass == null) { - // pre 1.3.0, we need to use Reflection to get the Writable class - getWritableClass[V]() - } else { - _valueWritableClass - } - - private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = { - val c = { - if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) { - classTag[T].runtimeClass - } else { - // We get the type of the Writable class by looking at the apply method which converts - // from T to Writable. Since we have two apply methods we filter out the one which - // is not of the form "java.lang.Object apply(java.lang.Object)" - implicitly[T => Writable].getClass.getDeclaredMethods().filter( - m => m.getReturnType().toString != "class java.lang.Object" && - m.getName() == "apply")(0).getReturnType - - } - // TODO: use something like WritableConverter to avoid reflection - } - c.asInstanceOf[Class[_ <: Writable]] - } - + // TODO the context bound (<%) above should be replaced with simple type bound and implicit + // conversion but is a breaking change. This should be fixed in Spark 3.x. /** * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key @@ -90,24 +58,24 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag // valueWritableClass at the compile time. To implement that, we need to add type parameters to // SequenceFileRDDFunctions. however, SequenceFileRDDFunctions is a public class so it will be a // breaking change. 
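For the key/value conversion logic discussed in the comment above, a usage sketch (output path assumed) where neither side is already a `Writable`, so both go through `anyToWritable` in the hunk that follows:

```scala
// Int keys and String values are converted to Writables implicitly, i.e. the
// convertKey && convertValue branch below; the output path is an assumption.
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("seqfile-sketch").setMaster("local[*]"))
sc.parallelize(Seq(1 -> "a", 2 -> "b"))
  .saveAsSequenceFile("/tmp/seqfile-sketch") // stored as IntWritable / Text on disk
sc.stop()
```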
- val convertKey = self.keyClass != keyWritableClass - val convertValue = self.valueClass != valueWritableClass + val convertKey = self.keyClass != _keyWritableClass + val convertValue = self.valueClass != _valueWritableClass - logInfo("Saving as sequence file of type (" + keyWritableClass.getSimpleName + "," + - valueWritableClass.getSimpleName + ")" ) + logInfo("Saving as sequence file of type " + + s"(${_keyWritableClass.getSimpleName},${_valueWritableClass.getSimpleName})" ) val format = classOf[SequenceFileOutputFormat[Writable, Writable]] val jobConf = new JobConf(self.context.hadoopConfiguration) if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec) + self.saveAsHadoopFile(path, _keyWritableClass, _valueWritableClass, format, jobConf, codec) } else if (!convertKey && convertValue) { self.map(x => (x._1, anyToWritable(x._2))).saveAsHadoopFile( - path, keyWritableClass, valueWritableClass, format, jobConf, codec) + path, _keyWritableClass, _valueWritableClass, format, jobConf, codec) } else if (convertKey && !convertValue) { self.map(x => (anyToWritable(x._1), x._2)).saveAsHadoopFile( - path, keyWritableClass, valueWritableClass, format, jobConf, codec) + path, _keyWritableClass, _valueWritableClass, format, jobConf, codec) } else if (convertKey && convertValue) { self.map(x => (anyToWritable(x._1), anyToWritable(x._2))).saveAsHadoopFile( - path, keyWritableClass, valueWritableClass, format, jobConf, codec) + path, _keyWritableClass, _valueWritableClass, format, jobConf, codec) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala index 8e1baae796fc..9f3d0745c33c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.{Text, Writable} import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.spark.{Partition, SparkContext} @@ -38,8 +39,12 @@ private[spark] class WholeTextFileRDD( extends NewHadoopRDD[Text, Text](sc, inputFormatClass, keyClass, valueClass, conf) { override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance val conf = getConf + // setMinPartitions below will call FileInputFormat.listStatus(), which can be quite slow when + // traversing a large number of directories and files. Parallelize it. 
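Because `setIfUnset` only supplies a default, the parallel listing introduced above can still be tuned explicitly before reading; a small sketch (glob path and thread count are assumptions):

```scala
// FileInputFormat.LIST_STATUS_NUM_THREADS is the Hadoop key
// "mapreduce.input.fileinputformat.list-status.num-threads"; setting it up front
// overrides the availableProcessors() default applied in this patch.
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("listing-sketch").setMaster("local[*]"))
sc.hadoopConfiguration.set(FileInputFormat.LIST_STATUS_NUM_THREADS, "32")
val smallFiles = sc.wholeTextFiles("/data/many-small-files/*") // (path, contents) pairs
println(smallFiles.count())
sc.stop()
```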
+ conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, + Runtime.getRuntime.availableProcessors().toString) + val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala index e00bc22aba44..1f8ab784a92b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala +++ b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala @@ -19,8 +19,8 @@ package org.apache.spark.rdd import scala.collection.mutable -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Partition +import org.apache.spark.annotation.DeveloperApi /** * ::DeveloperApi:: diff --git a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index ab72addb2466..facbb830a60d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -50,6 +50,7 @@ import org.apache.spark.util.PeriodicCheckpointer * {{{ * val (rdd1, rdd2, rdd3, ...) = ... * val cp = new PeriodicRDDCheckpointer(2, sc) + * cp.update(rdd1) * rdd1.count(); * // persisted: rdd1 * cp.update(rdd2) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 530743c03640..de2cc56bc6b1 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -40,7 +40,7 @@ private[spark] object RpcEnv { conf: SparkConf, securityManager: SecurityManager, clientMode: Boolean = false): RpcEnv = { - create(name, host, host, port, conf, securityManager, clientMode) + create(name, host, host, port, conf, securityManager, 0, clientMode) } def create( @@ -50,9 +50,10 @@ private[spark] object RpcEnv { port: Int, conf: SparkConf, securityManager: SecurityManager, + numUsableCores: Int, clientMode: Boolean): RpcEnv = { val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager, - clientMode) + numUsableCores, clientMode) new NettyRpcEnvFactory().create(config) } } @@ -201,4 +202,5 @@ private[spark] case class RpcEnvConfig( advertiseAddress: String, port: Int, securityManager: SecurityManager, + numUsableCores: Int, clientMode: Boolean) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 0557b7a3cc0b..3dc41f7f1279 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -125,9 +125,9 @@ private[spark] object RpcTimeout { var foundProp: Option[(String, String)] = None while (itr.hasNext && foundProp.isEmpty) { val propKey = itr.next() - conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } + conf.getOption(propKey).foreach { prop => foundProp = Some((propKey, prop)) } } - val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) + val finalProp = foundProp.getOrElse((timeoutPropList.head, defaultValue)) val timeout = { Utils.timeStringAsSeconds(finalProp._2).seconds } new RpcTimeout(timeout, finalProp._1) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 
a02cf30a5d83..904c4d02dd2a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -32,8 +32,11 @@ import org.apache.spark.util.ThreadUtils /** * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s). + * + * @param numUsableCores Number of CPU cores allocated to the process, for sizing the thread pool. + * If 0, will consider the available CPUs on the host. */ -private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { +private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging { private class EndpointData( val name: String, @@ -109,8 +112,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val iter = endpoints.keySet().iterator() while (iter.hasNext) { val name = iter.next - postMessage(name, message, (e) => logWarning(s"Message $message dropped. ${e.getMessage}")) - } + postMessage(name, message, (e) => { e match { + case e: RpcEnvStoppedException => logDebug (s"Message $message dropped. ${e.getMessage}") + case e: Throwable => logWarning(s"Message $message dropped. ${e.getMessage}") + }} + )} } /** Posts a message sent by a remote endpoint. */ @@ -189,8 +195,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { /** Thread pool used for dispatching messages. */ private val threadpool: ThreadPoolExecutor = { + val availableCores = + if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors() val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads", - math.max(2, Runtime.getRuntime.availableProcessors())) + math.max(2, availableCores)) val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") for (i <- 0 until numThreads) { pool.execute(new MessageLoop) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index ae4a6003517c..d32eba64e13e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -205,7 +205,12 @@ private[netty] class Inbox( try action catch { case NonFatal(e) => try endpoint.onError(e) catch { - case NonFatal(ee) => logError(s"Ignoring error", ee) + case NonFatal(ee) => + if (stopped) { + logDebug("Ignoring error", ee) + } else { + logError("Ignoring error", ee) + } } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index b316e5443f63..f951591e02a5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -44,14 +44,15 @@ private[netty] class NettyRpcEnv( val conf: SparkConf, javaSerializerInstance: JavaSerializerInstance, host: String, - securityManager: SecurityManager) extends RpcEnv(conf) with Logging { + securityManager: SecurityManager, + numUsableCores: Int) extends RpcEnv(conf) with Logging { private[netty] val transportConf = SparkTransportConf.fromSparkConf( conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", conf.getInt("spark.rpc.io.threads", 0)) - private val dispatcher: Dispatcher = new Dispatcher(this) + private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores) private val streamManager = new NettyStreamManager(this) @@ -185,7 +186,7 @@ private[netty] class NettyRpcEnv( try { 
dispatcher.postOneWayMessage(message) } catch { - case e: RpcEnvStoppedException => logWarning(e.getMessage) + case e: RpcEnvStoppedException => logDebug(e.getMessage) } } else { // Message to a remote RPC endpoint. @@ -203,7 +204,10 @@ private[netty] class NettyRpcEnv( def onFailure(e: Throwable): Unit = { if (!promise.tryFailure(e)) { - logWarning(s"Ignored failure: $e") + e match { + case e : RpcEnvStoppedException => logDebug (s"Ignored failure: $e") + case _ => logWarning(s"Ignored failure: $e") + } } } @@ -228,7 +232,7 @@ private[netty] class NettyRpcEnv( onFailure, (client, response) => onSuccess(deserialize[Any](client, response))) postToOutbox(message.receiver, rpcMessage) - promise.future.onFailure { + promise.future.failed.foreach { case _: TimeoutException => rpcMessage.onTimeout() case _ => }(ThreadUtils.sameThread) @@ -448,7 +452,7 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] val nettyEnv = new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress, - config.securityManager) + config.securityManager, config.numUsableCores) if (!config.clientMode) { val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => nettyEnv.startServer(config.bindAddress, actualPort) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index a7b7f58376f6..b7e068aa6835 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -45,7 +45,7 @@ private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends Outbo override def onFailure(e: Throwable): Unit = { e match { - case e1: RpcEnvStoppedException => logWarning(e1.getMessage) + case e1: RpcEnvStoppedException => logDebug(e1.getMessage) case e1: Throwable => logWarning(s"Failed to send one-way RPC.", e1) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index 28c45d800ed0..6da8865cd10d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -34,6 +34,7 @@ private[spark] class ApplicationEventListener extends SparkListener { var adminAcls: Option[String] = None var viewAclsGroups: Option[String] = None var adminAclsGroups: Option[String] = None + var appSparkVersion: Option[String] = None override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { appName = Some(applicationStart.appName) @@ -57,4 +58,10 @@ private[spark] class ApplicationEventListener extends SparkListener { adminAclsGroups = allProperties.get("spark.admin.acls.groups") } } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(sparkVersion) => + appSparkVersion = Some(sparkVersion) + case _ => + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala new file mode 100644 index 000000000000..8605e1da161c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
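Among the `NettyRpcEnv` changes above, `promise.future.onFailure { ... }` becomes `promise.future.failed.foreach { ... }`. The behaviour is the same, but `onFailure` is deprecated in newer Scala versions, while `future.failed` exposes the failure as a `Future[Throwable]` that composes with the ordinary combinators. A small self-contained sketch, using the global execution context as a stand-in for `ThreadUtils.sameThread`:

```scala
import java.util.concurrent.TimeoutException

import scala.concurrent.Promise
import scala.concurrent.ExecutionContext.Implicits.global // stand-in for ThreadUtils.sameThread

val promise = Promise[Int]()

// Runs only if the promise is completed with a failure.
promise.future.failed.foreach {
  case _: TimeoutException => println("ask timed out; abort the pending RPC")
  case other => println(s"ignored failure: $other")
}

promise.tryFailure(new TimeoutException("no reply within the RPC timeout"))
```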
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} + +import com.codahale.metrics.{Gauge, Timer} + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +/** + * An asynchronous queue for events. All events posted to this queue will be delivered to the child + * listeners in a separate thread. + * + * Delivery will only begin when the `start()` method is called. The `stop()` method should be + * called when no more events need to be delivered. + */ +private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics) + extends SparkListenerBus + with Logging { + + import AsyncEventQueue._ + + // Cap the capacity of the queue so we get an explicit error (rather than an OOM exception) if + // it's perpetually being added to more quickly than it's being drained. + private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent]( + conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) + + // Keep the event count separately, so that waitUntilEmpty() can be implemented properly; + // this allows that method to return only when the events in the queue have been fully + // processed (instead of just dequeued). + private val eventCount = new AtomicLong() + + /** A counter for dropped events. It will be reset every time we log it. */ + private val droppedEventsCounter = new AtomicLong(0L) + + /** When `droppedEventsCounter` was logged last time in milliseconds. */ + @volatile private var lastReportTimestamp = 0L + + private val logDroppedEvent = new AtomicBoolean(false) + + private var sc: SparkContext = null + + private val started = new AtomicBoolean(false) + private val stopped = new AtomicBoolean(false) + + private val droppedEvents = metrics.metricRegistry.counter(s"queue.$name.numDroppedEvents") + private val processingTime = metrics.metricRegistry.timer(s"queue.$name.listenerProcessingTime") + + // Remove the queue size gauge first, in case it was created by a previous incarnation of + // this queue that was removed from the listener bus. 
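The gauge comment in the new `AsyncEventQueue` is worth unpacking: Dropwizard's `MetricRegistry` refuses to register two metrics under the same name, so a queue that is removed and later re-created has to clear the stale gauge before registering its own. A stripped-down sketch of that pattern (the queue, registry, and metric name below are placeholders, not Spark objects):

```scala
import java.util.concurrent.LinkedBlockingQueue

import com.codahale.metrics.{Gauge, MetricRegistry}

val registry = new MetricRegistry
val queue = new LinkedBlockingQueue[String](10000)
val gaugeName = "queue.shared.size" // placeholder, named in the same style as the patch

// Drop any gauge left behind by a previous incarnation of this queue, then
// register a fresh one that reports the live queue depth.
registry.remove(gaugeName)
registry.register(gaugeName, new Gauge[Int] {
  override def getValue: Int = queue.size()
})
```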
+ metrics.metricRegistry.remove(s"queue.$name.size") + metrics.metricRegistry.register(s"queue.$name.size", new Gauge[Int] { + override def getValue: Int = eventQueue.size() + }) + + private val dispatchThread = new Thread(s"spark-listener-group-$name") { + setDaemon(true) + override def run(): Unit = Utils.tryOrStopSparkContext(sc) { + dispatch() + } + } + + private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) { + try { + var next: SparkListenerEvent = eventQueue.take() + while (next != POISON_PILL) { + val ctx = processingTime.time() + try { + super.postToAll(next) + } finally { + ctx.stop() + } + eventCount.decrementAndGet() + next = eventQueue.take() + } + eventCount.decrementAndGet() + } catch { + case ie: InterruptedException => + logInfo(s"Stopping listener queue $name.", ie) + } + } + + override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { + metrics.getTimerForListenerClass(listener.getClass.asSubclass(classOf[SparkListenerInterface])) + } + + /** + * Start an asynchronous thread to dispatch events to the underlying listeners. + * + * @param sc Used to stop the SparkContext in case the async dispatcher fails. + */ + private[scheduler] def start(sc: SparkContext): Unit = { + if (started.compareAndSet(false, true)) { + this.sc = sc + dispatchThread.start() + } else { + throw new IllegalStateException(s"$name already started!") + } + } + + /** + * Stop the listener bus. It will wait until the queued events have been processed, but new + * events will be dropped. + */ + private[scheduler] def stop(): Unit = { + if (!started.get()) { + throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + } + if (stopped.compareAndSet(false, true)) { + eventQueue.put(POISON_PILL) + eventCount.incrementAndGet() + } + dispatchThread.join() + } + + def post(event: SparkListenerEvent): Unit = { + if (stopped.get()) { + return + } + + eventCount.incrementAndGet() + if (eventQueue.offer(event)) { + return + } + + eventCount.decrementAndGet() + droppedEvents.inc() + droppedEventsCounter.incrementAndGet() + if (logDroppedEvent.compareAndSet(false, true)) { + // Only log the following message once to avoid duplicated annoying logs. + logError(s"Dropping event from queue $name. " + + "This likely means one of the listeners is too slow and cannot keep up with " + + "the rate at which tasks are being started by the scheduler.") + } + logTrace(s"Dropping event $event") + + val droppedCount = droppedEventsCounter.get + if (droppedCount > 0) { + // Don't log too frequently + if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { + // There may be multiple threads trying to decrease droppedEventsCounter. + // Use "compareAndSet" to make sure only one thread can win. + // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and + // then that thread will update it. + if (droppedEventsCounter.compareAndSet(droppedCount, 0)) { + val prevLastReportTimestamp = lastReportTimestamp + lastReportTimestamp = System.currentTimeMillis() + val previous = new java.util.Date(prevLastReportTimestamp) + logWarning(s"Dropped $droppedEvents events from $name since $previous.") + } + } + } + } + + /** + * For testing only. Wait until there are no more events in the queue. + * + * @return true if the queue is empty. 
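The dispatch loop above relies on a classic poison-pill shutdown: `stop()` enqueues a sentinel event, and the single dispatcher thread exits once it takes that sentinel off the queue, which guarantees every earlier event was delivered first. A toy version of the same mechanism, independent of Spark's listener types:

```scala
import java.util.concurrent.LinkedBlockingQueue

object PoisonPillDemo {
  private case object PoisonPill

  private val queue = new LinkedBlockingQueue[AnyRef](1000) // bounded, like the event queue

  /** Non-blocking enqueue; returns false (event dropped) when the queue is full. */
  def post(event: AnyRef): Boolean = queue.offer(event)

  /** Ask the dispatcher to finish the backlog and then exit. */
  def stop(): Unit = queue.put(PoisonPill)

  /** Runs on the dispatcher thread: drain until the sentinel shows up. */
  def dispatch(handle: AnyRef => Unit): Unit = {
    var next = queue.take()
    while (next ne PoisonPill) {
      handle(next)
      next = queue.take()
    }
  }
}
```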
+ */ + def waitUntilEmpty(deadline: Long): Boolean = { + while (eventCount.get() != 0) { + if (System.currentTimeMillis > deadline) { + return false + } + Thread.sleep(10) + } + true + } + +} + +private object AsyncEventQueue { + + val POISON_PILL = new SparkListenerEvent() { } + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index e130e609e4f6..cd8e61d6d020 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -61,6 +61,7 @@ private[scheduler] class BlacklistTracker ( private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC) private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE) val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf) + private val BLACKLIST_FETCH_FAILURE_ENABLED = conf.get(config.BLACKLIST_FETCH_FAILURE_ENABLED) /** * A map from executorId to information on task failures. Tracks the time of each task failure, @@ -145,6 +146,74 @@ private[scheduler] class BlacklistTracker ( nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry) } + private def killBlacklistedExecutor(exec: String): Unit = { + if (conf.get(config.BLACKLIST_KILL_ENABLED)) { + allocationClient match { + case Some(a) => + logInfo(s"Killing blacklisted executor id $exec " + + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") + a.killExecutors(Seq(exec), true, true) + case None => + logWarning(s"Not attempting to kill blacklisted executor id $exec " + + s"since allocation client is not defined.") + } + } + } + + private def killExecutorsOnBlacklistedNode(node: String): Unit = { + if (conf.get(config.BLACKLIST_KILL_ENABLED)) { + allocationClient match { + case Some(a) => + logInfo(s"Killing all executors on blacklisted host $node " + + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") + if (a.killExecutorsOnHost(node) == false) { + logError(s"Killing executors on node $node failed.") + } + case None => + logWarning(s"Not attempting to kill executors on blacklisted host $node " + + s"since allocation client is not defined.") + } + } + } + + def updateBlacklistForFetchFailure(host: String, exec: String): Unit = { + if (BLACKLIST_FETCH_FAILURE_ENABLED) { + // If we blacklist on fetch failures, we are implicitly saying that we believe the failure is + // non-transient, and can't be recovered from (even if this is the first fetch failure, + // stage is retried after just one failure, so we don't always get a chance to collect + // multiple fetch failures). + // If the external shuffle-service is on, then every other executor on this node would + // be suffering from the same issue, so we should blacklist (and potentially kill) all + // of them immediately. 
+ + val now = clock.getTimeMillis() + val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS + + if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { + if (!nodeIdToBlacklistExpiryTime.contains(host)) { + logInfo(s"blacklisting node $host due to fetch failure of external shuffle service") + + nodeIdToBlacklistExpiryTime.put(host, expiryTimeForNewBlacklists) + listenerBus.post(SparkListenerNodeBlacklisted(now, host, 1)) + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + killExecutorsOnBlacklistedNode(host) + updateNextExpiryTime() + } + } else if (!executorIdToBlacklistStatus.contains(exec)) { + logInfo(s"Blacklisting executor $exec due to fetch failure") + + executorIdToBlacklistStatus.put(exec, BlacklistedExecutor(host, expiryTimeForNewBlacklists)) + // We hardcoded number of failure tasks to 1 for fetch failure, because there's no + // reattempt for such failure. + listenerBus.post(SparkListenerExecutorBlacklisted(now, exec, 1)) + updateNextExpiryTime() + killBlacklistedExecutor(exec) + + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) + blacklistedExecsOnNode += exec + } + } + } def updateBlacklistForSuccessfulTaskSet( stageId: Int, @@ -174,17 +243,7 @@ private[scheduler] class BlacklistTracker ( listenerBus.post(SparkListenerExecutorBlacklisted(now, exec, newTotal)) executorIdToFailureList.remove(exec) updateNextExpiryTime() - if (conf.get(config.BLACKLIST_KILL_ENABLED)) { - allocationClient match { - case Some(allocationClient) => - logInfo(s"Killing blacklisted executor id $exec " + - s"since spark.blacklist.killBlacklistedExecutors is set.") - allocationClient.killExecutors(Seq(exec), true, true) - case None => - logWarning(s"Not attempting to kill blacklisted executor id $exec " + - s"since allocation client is not defined.") - } - } + killBlacklistedExecutor(exec) // In addition to blacklisting the executor, we also update the data for failures on the // node, and potentially put the entire node into a blacklist as well. 
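For context on the new `updateBlacklistForFetchFailure` path above: it only fires when fetch-failure blacklisting is switched on, and killing the implicated executors is gated separately. A hedged sketch of the application-level opt-in follows; `spark.blacklist.killBlacklistedExecutors` appears in this patch's own log messages, while the fetch-failure key is the name `BLACKLIST_FETCH_FAILURE_ENABLED` is expected to resolve to and should be checked against the config package before relying on it.

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.blacklist.enabled", "true")
  .set("spark.blacklist.application.fetchFailure.enabled", "true") // blacklist on fetch failure (assumed key)
  .set("spark.blacklist.killBlacklistedExecutors", "true")         // also kill what gets blacklisted
```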
@@ -199,19 +258,7 @@ private[scheduler] class BlacklistTracker ( nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists) listenerBus.post(SparkListenerNodeBlacklisted(now, node, blacklistedExecsOnNode.size)) _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) - if (conf.get(config.BLACKLIST_KILL_ENABLED)) { - allocationClient match { - case Some(allocationClient) => - logInfo(s"Killing all executors on blacklisted host $node " + - s"since spark.blacklist.killBlacklistedExecutors is set.") - if (allocationClient.killExecutorsOnHost(node) == false) { - logError(s"Killing executors on node $node failed.") - } - case None => - logWarning(s"Not attempting to kill executors on blacklisted host $node " + - s"since allocation client is not defined.") - } - } + killExecutorsOnBlacklistedNode(node) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index aab177f257a8..9153751d03c1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.annotation.tailrec import scala.collection.Map -import scala.collection.mutable.{HashMap, HashSet, Stack} +import scala.collection.mutable.{ArrayStack, HashMap, HashSet} import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps @@ -36,6 +36,7 @@ import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging +import org.apache.spark.internal.config import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD @@ -58,7 +59,7 @@ import org.apache.spark.util._ * set of map output files, and another to read those files after a barrier). In the end, every * stage will have only shuffle dependencies on other stages, and may compute multiple operations * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of - * various RDDs (MappedRDD, FilteredRDD, etc). + * various RDDs * * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred * locations to run each task on, based on the current cache status, and passes these to the @@ -187,6 +188,14 @@ class DAGScheduler( /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + /** + * Whether to unregister all the outputs on the host in condition that we receive a FetchFailure, + * this is set default to false, which means, we only unregister the outputs related to the exact + * executor(instead of the host) on a FetchFailure. + */ + private[scheduler] val unRegisterOutputOnHostOnFetchFailure = + sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) + /** * Number of consecutive stage attempts allowed before a stage is aborted. */ @@ -250,6 +259,13 @@ class DAGScheduler( eventProcessLoop.post(ExecutorLost(execId, reason)) } + /** + * Called by TaskScheduler implementation when a worker is removed. 
+ */ + def workerRemoved(workerId: String, host: String, message: String): Unit = { + eventProcessLoop.post(WorkerRemoved(workerId, host, message)) + } + /** * Called by TaskScheduler implementation when a host is added. */ @@ -265,6 +281,13 @@ class DAGScheduler( eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } + /** + * Called by the TaskSetManager when it decides a speculative task is needed. + */ + def speculativeTaskSubmitted(task: Task[_]): Unit = { + eventProcessLoop.post(SpeculativeTaskSubmitted(task)) + } + private[scheduler] def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times @@ -328,25 +351,14 @@ class DAGScheduler( val numTasks = rdd.partitions.length val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() - val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep) + val stage = new ShuffleMapStage( + id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker) stageIdToStage(id) = stage shuffleIdToMapStage(shuffleDep.shuffleId) = stage updateJobIdStageIdMaps(jobId, stage) - if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { - // A previously run stage generated partitions for this shuffle, so for each output - // that's still available, copy information about that output location to the new stage - // (so we don't unnecessarily re-compute that data). - val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) - val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - (0 until locs.length).foreach { i => - if (locs(i) ne null) { - // locs(i) will be null if missing - stage.addOutputLoc(i, locs(i)) - } - } - } else { + if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") @@ -384,12 +396,12 @@ class DAGScheduler( /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ private def getMissingAncestorShuffleDependencies( - rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { - val ancestors = new Stack[ShuffleDependency[_, _, _]] + rdd: RDD[_]): ArrayStack[ShuffleDependency[_, _, _]] = { + val ancestors = new ArrayStack[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError // caused by recursively visiting - val waitingForVisit = new Stack[RDD[_]] + val waitingForVisit = new ArrayStack[RDD[_]] waitingForVisit.push(rdd) while (waitingForVisit.nonEmpty) { val toVisit = waitingForVisit.pop() @@ -422,7 +434,7 @@ class DAGScheduler( rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = { val parents = new HashSet[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] - val waitingForVisit = new Stack[RDD[_]] + val waitingForVisit = new ArrayStack[RDD[_]] waitingForVisit.push(rdd) while (waitingForVisit.nonEmpty) { val toVisit = waitingForVisit.pop() @@ -444,7 +456,7 @@ class DAGScheduler( val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError // caused by recursively visiting - val waitingForVisit = new Stack[RDD[_]] + val waitingForVisit = new ArrayStack[RDD[_]] def visit(rdd: RDD[_]) { if 
(!visited(rdd)) { visited += rdd @@ -618,12 +630,7 @@ class DAGScheduler( properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) - // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`, - // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that - // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's - // safe to pass in null here. For more detail, see SPARK-13747. - val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] - waiter.completionFuture.ready(Duration.Inf)(awaitPermission) + ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf) waiter.completionFuture.value.get match { case scala.util.Success(_) => logInfo("Job %d finished: %s, took %f s".format @@ -812,6 +819,10 @@ class DAGScheduler( listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) } + private[scheduler] def handleSpeculativeTaskSubmitted(task: Task[_]): Unit = { + listenerBus.post(SparkListenerSpeculativeTaskSubmitted(task.stageId)) + } + private[scheduler] def handleTaskSetFailed( taskSet: TaskSet, reason: String, @@ -988,6 +999,13 @@ class DAGScheduler( } stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq) + + // If there are tasks to execute, record the submission time of the stage. Otherwise, + // post the even without the submission time, which indicates that this stage was + // skipped. + if (partitionsToCompute.nonEmpty) { + stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) + } listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. @@ -1059,7 +1077,6 @@ class DAGScheduler( s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) - stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark // the stage as completed here in case there are no tasks to run @@ -1116,6 +1133,25 @@ class DAGScheduler( } } + private def postTaskEnd(event: CompletionEvent): Unit = { + val taskMetrics: TaskMetrics = + if (event.accumUpdates.nonEmpty) { + try { + TaskMetrics.fromAccumulators(event.accumUpdates) + } catch { + case NonFatal(e) => + val taskId = event.taskInfo.taskId + logError(s"Error when attempting to reconstruct metrics for task $taskId", e) + null + } + } else { + null + } + + listenerBus.post(SparkListenerTaskEnd(event.task.stageId, event.task.stageAttemptId, + Utils.getFormattedClassName(event.task), event.reason, event.taskInfo, taskMetrics)) + } + /** * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. @@ -1132,34 +1168,36 @@ class DAGScheduler( event.taskInfo.attemptNumber, // this is a task attempt number event.reason) - // Reconstruct task metrics. Note: this may be null if the task has failed. 
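The `Stack` to `ArrayStack` swap that runs through the DAGScheduler hunks above keeps the existing idiom, an explicit stack instead of recursion so that very long RDD lineages cannot overflow the JVM call stack, while moving to an array-backed structure with cheaper push and pop. A generic sketch of that traversal pattern (names here are illustrative, not Spark's):

```scala
import scala.collection.mutable.{ArrayStack, HashSet}

/** Collect every node reachable from `start`, iteratively (no recursion). */
def reachable[T](start: T)(children: T => Seq[T]): Set[T] = {
  val visited = new HashSet[T]
  val waitingForVisit = new ArrayStack[T]
  waitingForVisit.push(start)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      children(toVisit).foreach(waitingForVisit.push)
    }
  }
  visited.toSet
}

// Example: a small DAG encoded as a Map from node to children.
val edges = Map(0 -> Seq(1, 2), 1 -> Seq(3), 2 -> Seq(3), 3 -> Seq(4), 4 -> Seq())
assert(reachable(0)(edges) == Set(0, 1, 2, 3, 4))
```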
- val taskMetrics: TaskMetrics = - if (event.accumUpdates.nonEmpty) { - try { - TaskMetrics.fromAccumulators(event.accumUpdates) - } catch { - case NonFatal(e) => - logError(s"Error when attempting to reconstruct metrics for task $taskId", e) - null - } - } else { - null - } - - // The stage may have already finished when we get this event -- eg. maybe it was a - // speculative task. It is important that we send the TaskEnd event in any case, so listeners - // are properly notified and can chose to handle it. For instance, some listeners are - // doing their own accounting and if they don't get the task end event they think - // tasks are still running when they really aren't. - listenerBus.post(SparkListenerTaskEnd( - stageId, task.stageAttemptId, taskType, event.reason, event.taskInfo, taskMetrics)) - if (!stageIdToStage.contains(task.stageId)) { + // The stage may have already finished when we get this event -- eg. maybe it was a + // speculative task. It is important that we send the TaskEnd event in any case, so listeners + // are properly notified and can chose to handle it. For instance, some listeners are + // doing their own accounting and if they don't get the task end event they think + // tasks are still running when they really aren't. + postTaskEnd(event) + // Skip all the actions if the stage has been cancelled. return } val stage = stageIdToStage(task.stageId) + + // Make sure the task's accumulators are updated before any other processing happens, so that + // we can post a task end event before any jobs or stages are updated. The accumulators are + // only updated in certain cases. + event.reason match { + case Success => + stage match { + case rs: ResultStage if rs.activeJob.isEmpty => + // Ignore update if task's job has finished. + case _ => + updateAccumulators(event) + } + case _: ExceptionFailure => updateAccumulators(event) + case _ => + } + postTaskEnd(event) + event.reason match { case Success => task match { @@ -1170,7 +1208,6 @@ class DAGScheduler( resultStage.activeJob match { case Some(job) => if (!job.finished(rt.outputId)) { - updateAccumulators(event) job.finished(rt.outputId) = true job.numFinished += 1 // If the whole job has finished, remove it @@ -1197,7 +1234,6 @@ class DAGScheduler( case smt: ShuffleMapTask => val shuffleStage = stage.asInstanceOf[ShuffleMapStage] - updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) @@ -1216,7 +1252,8 @@ class DAGScheduler( // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as // available. - shuffleStage.addOutputLoc(smt.partitionId, status) + mapOutputTracker.registerMapOutput( + shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) // Remove the task's partition from pending partitions. This may have already been // done above, but will not have been done yet in cases where the task attempt was // from an earlier attempt of the stage (i.e., not the attempt that's currently @@ -1233,16 +1270,14 @@ class DAGScheduler( logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. 
- // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) + // This call to increment the epoch may not be strictly necessary, but it is retained + // for now in order to minimize the changes in behavior from an earlier version of the + // code. This existing behavior of always incrementing the epoch following any + // successful shuffle map stage completion may have benefits by causing unneeded + // cached map outputs to be cleaned up earlier on executors. In the future we can + // consider removing this call, but this will require some extra investigation. + // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details. + mapOutputTracker.incrementEpoch() clearCacheLocs() @@ -1342,13 +1377,26 @@ class DAGScheduler( } // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) } // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, filesLost = true, Some(task.epoch)) + val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled && + unRegisterOutputOnHostOnFetchFailure) { + // We had a fetch failure with the external shuffle service, so we + // assume all shuffle data on the node is bad. + Some(bmAddress.host) + } else { + // Unregister shuffle data just for one executor (we don't have any + // reason to believe shuffle data has been lost for the entire host). + None + } + removeExecutorAndUnregisterOutputs( + execId = bmAddress.executorId, + fileLost = true, + hostToUnregisterOutputs = hostToUnregisterOutputs, + maybeEpoch = Some(task.epoch)) } } @@ -1356,8 +1404,7 @@ class DAGScheduler( // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case exceptionFailure: ExceptionFailure => - // Tasks failed with exceptions might still have accumulator updates. - updateAccumulators(event) + // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. @@ -1382,35 +1429,65 @@ class DAGScheduler( */ private[scheduler] def handleExecutorLost( execId: String, - filesLost: Boolean, - maybeEpoch: Option[Long] = None) { + workerLost: Boolean): Unit = { + // if the cluster manager explicitly tells us that the entire worker was lost, then + // we know to unregister shuffle output. (Note that "worker" specifically refers to the process + // from a Standalone cluster, where the shuffle service lives in the Worker.) 
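The fetch-failure handling above boils down to a small decision: shuffle output is unregistered for the whole host only when an external shuffle service is serving that data and the new, off-by-default `UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE` flag is set; otherwise only the failing executor's output is dropped. Distilled into a standalone function for clarity (a sketch, not the patch's API):

```scala
/**
 * Mirrors the branch above: Some(host) means "unregister every output on the
 * host", None means "unregister only the failing executor's output".
 */
def hostToUnregisterOutputs(
    externalShuffleServiceEnabled: Boolean,
    unRegisterOutputOnHostOnFetchFailure: Boolean, // config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE
    host: String): Option[String] = {
  if (externalShuffleServiceEnabled && unRegisterOutputOnHostOnFetchFailure) Some(host)
  else None
}
```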
+ val fileLost = workerLost || !env.blockManager.externalShuffleServiceEnabled + removeExecutorAndUnregisterOutputs( + execId = execId, + fileLost = fileLost, + hostToUnregisterOutputs = None, + maybeEpoch = None) + } + + private def removeExecutorAndUnregisterOutputs( + execId: String, + fileLost: Boolean, + hostToUnregisterOutputs: Option[String], + maybeEpoch: Option[Long] = None): Unit = { val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { failedEpoch(execId) = currentEpoch logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) blockManagerMaster.removeExecutor(execId) - - if (filesLost || !env.blockManager.externalShuffleServiceEnabled) { - logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleIdToMapStage) { - stage.removeOutputsOnExecutor(execId) - mapOutputTracker.registerMapOutputs( - shuffleId, - stage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) - } - if (shuffleIdToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() + if (fileLost) { + hostToUnregisterOutputs match { + case Some(host) => + logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch)) + mapOutputTracker.removeOutputsOnHost(host) + case None => + logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) + mapOutputTracker.removeOutputsOnExecutor(execId) } clearCacheLocs() + + } else { + logDebug("Additional executor lost message for %s (epoch %d)".format(execId, currentEpoch)) } - } else { - logDebug("Additional executor lost message for " + execId + - "(epoch " + currentEpoch + ")") } } + /** + * Responds to a worker being removed. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use workerRemoved() to post a loss event from outside. + * + * We will assume that we've lost all shuffle blocks associated with the host if a worker is + * removed, so we will remove them all from MapStatus. + * + * @param workerId identifier of the worker that is removed. + * @param host host of the worker that is removed. + * @param message the reason why the worker is removed. + */ + private[scheduler] def handleWorkerRemoved( + workerId: String, + host: String, + message: String): Unit = { + logInfo("Shuffle files lost for worker %s on host %s".format(workerId, host)) + mapOutputTracker.removeOutputsOnHost(host) + clearCacheLocs() + } + private[scheduler] def handleExecutorAdded(execId: String, host: String) { // remove from failedEpoch(execId) ? 
if (failedEpoch.contains(execId)) { @@ -1556,7 +1633,7 @@ class DAGScheduler( val visitedRdds = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError // caused by recursively visiting - val waitingForVisit = new Stack[RDD[_]] + val waitingForVisit = new ArrayStack[RDD[_]] def visit(rdd: RDD[_]) { if (!visitedRdds(rdd)) { visitedRdds += rdd @@ -1700,15 +1777,21 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler dagScheduler.handleExecutorAdded(execId, host) case ExecutorLost(execId, reason) => - val filesLost = reason match { + val workerLost = reason match { case SlaveLost(_, true) => true case _ => false } - dagScheduler.handleExecutorLost(execId, filesLost) + dagScheduler.handleExecutorLost(execId, workerLost) + + case WorkerRemoved(workerId, host, message) => + dagScheduler.handleWorkerRemoved(workerId, host, message) case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) + case SpeculativeTaskSubmitted(task) => + dagScheduler.handleSpeculativeTaskSubmitted(task) + case GettingResultEvent(taskInfo) => dagScheduler.handleGetTaskResult(taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index cda0585f154a..54ab8f8b3e1d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -86,8 +86,15 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend private[scheduler] case class ExecutorLost(execId: String, reason: ExecutorLossReason) extends DAGSchedulerEvent +private[scheduler] case class WorkerRemoved(workerId: String, host: String, message: String) + extends DAGSchedulerEvent + private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) extends DAGSchedulerEvent private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent + +private[scheduler] +case class SpeculativeTaskSubmitted(task: Task[_]) extends DAGSchedulerEvent + diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index a7dbf87915b2..9dafa0b7646b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io._ import java.net.URI import java.nio.charset.StandardCharsets +import java.util.EnumSet import java.util.Locale import scala.collection.mutable @@ -28,6 +29,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path} import org.apache.hadoop.fs.permission.FsPermission +import org.apache.hadoop.hdfs.DFSOutputStream +import org.apache.hadoop.hdfs.client.HdfsDataOutputStream.SyncFlag import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods._ @@ -96,8 +99,8 @@ private[spark] class EventLoggingListener( } val workingPath = logPath + IN_PROGRESS - val uri = new URI(workingPath) val path = new Path(workingPath) + val uri = path.toUri val defaultFs = FileSystem.getDefaultUri(hadoopConf).getScheme val isDefaultLocal = defaultFs == null || defaultFs == "file" @@ -119,7 +122,7 @@ private[spark] class EventLoggingListener( val cstream = 
compressionCodec.map(_.compressedOutputStream(dstream)).getOrElse(dstream) val bstream = new BufferedOutputStream(cstream, outputBufferSize) - EventLoggingListener.initEventLog(bstream) + EventLoggingListener.initEventLog(bstream, testing, loggedEvents) fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) writer = Some(new PrintWriter(bstream)) logInfo("Logging events to %s".format(logPath)) @@ -138,7 +141,10 @@ private[spark] class EventLoggingListener( // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) - hadoopDataStream.foreach(_.hflush()) + hadoopDataStream.foreach(ds => ds.getWrappedStream match { + case wrapped: DFSOutputStream => wrapped.hsync(EnumSet.of(SyncFlag.UPDATE_LENGTH)) + case _ => ds.hflush() + }) } if (testing) { loggedEvents += eventJson @@ -283,10 +289,17 @@ private[spark] object EventLoggingListener extends Logging { * * @param logStream Raw output stream to the event log file. */ - def initEventLog(logStream: OutputStream): Unit = { + def initEventLog( + logStream: OutputStream, + testing: Boolean, + loggedEvents: ArrayBuffer[JValue]): Unit = { val metadata = SparkListenerLogStart(SPARK_VERSION) - val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n" + val eventJson = JsonProtocol.logStartToJson(metadata) + val metadataJson = compact(eventJson) + "\n" logStream.write(metadataJson.getBytes(StandardCharsets.UTF_8)) + if (testing && loggedEvents != null) { + loggedEvents += eventJson + } } /** @@ -313,7 +326,7 @@ private[spark] object EventLoggingListener extends Logging { appId: String, appAttemptId: Option[String], compressionCodecName: Option[String] = None): String = { - val base = logBaseDir.toString.stripSuffix("/") + "/" + sanitize(appId) + val base = new Path(logBaseDir).toString.stripSuffix("/") + "/" + sanitize(appId) val codec = compressionCodecName.map("." 
+ _).getOrElse("") if (appAttemptId.isDefined) { base + "_" + sanitize(appAttemptId.get) + codec @@ -338,14 +351,14 @@ private[spark] object EventLoggingListener extends Logging { // Since we sanitize the app ID to not include periods, it is safe to split on it val logName = log.getName.stripSuffix(IN_PROGRESS) val codecName: Option[String] = logName.split("\\.").tail.lastOption - val codec = codecName.map { c => - codecMap.getOrElseUpdate(c, CompressionCodec.createCodec(new SparkConf, c)) - } try { + val codec = codecName.map { c => + codecMap.getOrElseUpdate(c, CompressionCodec.createCodec(new SparkConf, c)) + } codec.map(_.compressedInputStream(in)).getOrElse(in) } catch { - case e: Exception => + case e: Throwable => in.close() throw e } diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 5533f7b1f236..2f93c497c577 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -17,14 +17,22 @@ package org.apache.spark.scheduler +import java.util.{List => JList} import java.util.concurrent._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.reflect.ClassTag import scala.util.DynamicVariable -import org.apache.spark.{SparkContext, SparkException} +import com.codahale.metrics.{Counter, MetricRegistry, Timer} + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -import org.apache.spark.util.Utils +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source /** * Asynchronously passes SparkListenerEvents to registered SparkListeners. @@ -33,24 +41,13 @@ import org.apache.spark.util.Utils * has started will events be actually propagated to all attached listeners. This listener bus * is stopped when `stop()` is called, and it will drop further events after stopping. */ -private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus { - - self => +private[spark] class LiveListenerBus(conf: SparkConf) { import LiveListenerBus._ - // Cap the capacity of the event queue so we get an explicit error (rather than - // an OOM exception) if it's perpetually being added to more quickly than it's being drained. - private lazy val EVENT_QUEUE_CAPACITY = validateAndGetQueueSize() - private lazy val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) + private var sparkContext: SparkContext = _ - private def validateAndGetQueueSize(): Int = { - val queueSize = sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_SIZE) - if (queueSize <= 0) { - throw new SparkException("spark.scheduler.listenerbus.eventqueue.size must be > 0!") - } - queueSize - } + private[spark] val metrics = new LiveListenerBusMetrics(conf) // Indicate if `start()` is called private val started = new AtomicBoolean(false) @@ -63,41 +60,74 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa /** When `droppedEventsCounter` was logged last time in milliseconds. 
*/ @volatile private var lastReportTimestamp = 0L - // Indicate if we are processing some event - // Guarded by `self` - private var processingEvent = false - - private val logDroppedEvent = new AtomicBoolean(false) - - // A counter that represents the number of events produced and consumed in the queue - private val eventLock = new Semaphore(0) - - private val listenerThread = new Thread(name) { - setDaemon(true) - override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { - LiveListenerBus.withinListenerThread.withValue(true) { - while (true) { - eventLock.acquire() - self.synchronized { - processingEvent = true - } - try { - val event = eventQueue.poll - if (event == null) { - // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } - return - } - postToAll(event) - } finally { - self.synchronized { - processingEvent = false - } - } + private val queues = new CopyOnWriteArrayList[AsyncEventQueue]() + + /** Add a listener to queue shared by all non-internal listeners. */ + def addToSharedQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, SHARED_QUEUE) + } + + /** Add a listener to the executor management queue. */ + def addToManagementQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, EXECUTOR_MANAGEMENT_QUEUE) + } + + /** Add a listener to the application status queue. */ + def addToStatusQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, APP_STATUS_QUEUE) + } + + /** Add a listener to the event log queue. */ + def addToEventLogQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, EVENT_LOG_QUEUE) + } + + /** + * Add a listener to a specific queue, creating a new queue if needed. Queues are independent + * of each other (each one uses a separate thread for delivering events), allowing slower + * listeners to be somewhat isolated from others. + */ + private def addToQueue(listener: SparkListenerInterface, queue: String): Unit = synchronized { + if (stopped.get()) { + throw new IllegalStateException("LiveListenerBus is stopped.") + } + + queues.asScala.find(_.name == queue) match { + case Some(queue) => + queue.addListener(listener) + + case None => + val newQueue = new AsyncEventQueue(queue, conf, metrics) + newQueue.addListener(listener) + if (started.get()) { + newQueue.start(sparkContext) } + queues.add(newQueue) + } + } + + def removeListener(listener: SparkListenerInterface): Unit = synchronized { + // Remove listener from all queues it was added to, and stop queues that have become empty. + queues.asScala + .filter { queue => + queue.removeListener(listener) + queue.listeners.isEmpty() + } + .foreach { toRemove => + if (started.get() && !stopped.get()) { + toRemove.stop() + } + queues.remove(toRemove) + } + } + + /** Post an event to all queues. */ + def post(event: SparkListenerEvent): Unit = { + if (!stopped.get()) { + metrics.numEventsPosted.inc() + val it = queues.iterator() + while (it.hasNext()) { + it.next().post(event) } } } @@ -109,45 +139,16 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa * listens for any additional events asynchronously while the listener bus is still running. * This should only be called once. * + * @param sc Used to stop the SparkContext in case the listener thread dies. 
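The queue-routing methods above replace the single shared event queue with several named ones ("shared", "appStatus", "executorManagement", "eventLog"), each drained by its own `AsyncEventQueue` thread, so a slow listener only delays the queue it lives in. A toy model of the find-or-create routing follows; the types are simplified stand-ins for `AsyncEventQueue` and `SparkListenerInterface`, not Spark classes.

```scala
import java.util.concurrent.CopyOnWriteArrayList

import scala.collection.JavaConverters._

object ListenerRoutingModel {
  final class NamedQueue(val name: String) {
    val listeners = new CopyOnWriteArrayList[AnyRef]()
  }

  private val queues = new CopyOnWriteArrayList[NamedQueue]()

  /** Attach `listener` to the queue called `queueName`, creating it on first use. */
  def addToQueue(listener: AnyRef, queueName: String): Unit = synchronized {
    queues.asScala.find(_.name == queueName) match {
      case Some(q) => q.listeners.add(listener)
      case None =>
        val q = new NamedQueue(queueName)
        q.listeners.add(listener)
        queues.add(q)
    }
  }

  /** Deliver an event to every queue; each real queue would hand it to its own thread. */
  def post(event: AnyRef)(deliver: (NamedQueue, AnyRef) => Unit): Unit = {
    val it = queues.iterator()
    while (it.hasNext) deliver(it.next(), event)
  }
}
```

Routing by fixed queue names keeps listener registration stable while letting the bus create and drop queues as listeners come and go.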
*/ - def start(): Unit = { - if (started.compareAndSet(false, true)) { - listenerThread.start() - } else { - throw new IllegalStateException(s"$name already started!") - } - } - - def post(event: SparkListenerEvent): Unit = { - if (stopped.get) { - // Drop further events to make `listenerThread` exit ASAP - logError(s"$name has already stopped! Dropping event $event") - return - } - val eventAdded = eventQueue.offer(event) - if (eventAdded) { - eventLock.release() - } else { - onDropEvent(event) - droppedEventsCounter.incrementAndGet() + def start(sc: SparkContext, metricsSystem: MetricsSystem): Unit = synchronized { + if (!started.compareAndSet(false, true)) { + throw new IllegalStateException("LiveListenerBus already started.") } - val droppedEvents = droppedEventsCounter.get - if (droppedEvents > 0) { - // Don't log too frequently - if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { - // There may be multiple threads trying to decrease droppedEventsCounter. - // Use "compareAndSet" to make sure only one thread can win. - // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and - // then that thread will update it. - if (droppedEventsCounter.compareAndSet(droppedEvents, 0)) { - val prevLastReportTimestamp = lastReportTimestamp - lastReportTimestamp = System.currentTimeMillis() - logWarning(s"Dropped $droppedEvents SparkListenerEvents since " + - new java.util.Date(prevLastReportTimestamp)) - } - } - } + this.sparkContext = sc + queues.asScala.foreach(_.start(sc)) + metricsSystem.registerSource(metrics) } /** @@ -158,71 +159,100 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa */ @throws(classOf[TimeoutException]) def waitUntilEmpty(timeoutMillis: Long): Unit = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!queueIsEmpty) { - if (System.currentTimeMillis > finishTime) { - throw new TimeoutException( - s"The event queue is not empty after $timeoutMillis milliseconds") + val deadline = System.currentTimeMillis + timeoutMillis + queues.asScala.foreach { queue => + if (!queue.waitUntilEmpty(deadline)) { + throw new TimeoutException(s"The event queue is not empty after $timeoutMillis ms.") } - /* Sleep rather than using wait/notify, because this is used only for testing and - * wait/notify add overhead in the general case. */ - Thread.sleep(10) } } - /** - * For testing only. Return whether the listener daemon thread is still alive. - * Exposed for testing. - */ - def listenerThreadIsAlive: Boolean = listenerThread.isAlive - - /** - * Return whether the event queue is empty. - * - * The use of synchronized here guarantees that all events that once belonged to this queue - * have already been processed by all attached listeners, if this returns true. - */ - private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } - /** * Stop the listener bus. It will wait until the queued events have been processed, but drop the * new events after stopping. */ def stop(): Unit = { if (!started.get()) { - throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + throw new IllegalStateException(s"Attempted to stop bus that has not yet started!") } - if (stopped.compareAndSet(false, true)) { - // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know - // `stop` is called. 
- eventLock.release() - listenerThread.join() - } else { - // Keep quiet + + if (!stopped.compareAndSet(false, true)) { + return } - } - /** - * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be - * notified with the dropped events. - * - * Note: `onDropEvent` can be called in any thread. - */ - def onDropEvent(event: SparkListenerEvent): Unit = { - if (logDroppedEvent.compareAndSet(false, true)) { - // Only log the following message once to avoid duplicated annoying logs. - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with " + - "the rate at which tasks are being started by the scheduler.") + synchronized { + queues.asScala.foreach(_.stop()) + queues.clear() } } + + // For testing only. + private[spark] def findListenersByClass[T <: SparkListenerInterface : ClassTag](): Seq[T] = { + queues.asScala.flatMap { queue => queue.findListenersByClass[T]() } + } + + // For testing only. + private[spark] def listeners: JList[SparkListenerInterface] = { + queues.asScala.flatMap(_.listeners.asScala).asJava + } + + // For testing only. + private[scheduler] def activeQueues(): Set[String] = { + queues.asScala.map(_.name).toSet + } + } private[spark] object LiveListenerBus { // Allows for Context to check whether stop() call is made within listener thread val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) - /** The thread name of Spark listener bus */ - val name = "SparkListenerBus" + private[scheduler] val SHARED_QUEUE = "shared" + + private[scheduler] val APP_STATUS_QUEUE = "appStatus" + + private[scheduler] val EXECUTOR_MANAGEMENT_QUEUE = "executorManagement" + + private[scheduler] val EVENT_LOG_QUEUE = "eventLog" } +private[spark] class LiveListenerBusMetrics(conf: SparkConf) + extends Source with Logging { + + override val sourceName: String = "LiveListenerBus" + override val metricRegistry: MetricRegistry = new MetricRegistry + + /** + * The total number of events posted to the LiveListenerBus. This is a count of the total number + * of events which have been produced by the application and sent to the listener bus, NOT a + * count of the number of events which have been processed and delivered to listeners (or dropped + * without being delivered). + */ + val numEventsPosted: Counter = metricRegistry.counter(MetricRegistry.name("numEventsPosted")) + + // Guarded by synchronization. + private val perListenerClassTimers = mutable.Map[String, Timer]() + + /** + * Returns a timer tracking the processing time of the given listener class. + * events processed by that listener. This method is thread-safe. 
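`getTimerForListenerClass` above hands each listener class its own Dropwizard `Timer`, and `AsyncEventQueue` wraps event delivery in it. For readers unfamiliar with the metrics API, the measurement idiom looks like the sketch below; the listener class name is a placeholder.

```scala
import com.codahale.metrics.{MetricRegistry, Timer}

val registry = new MetricRegistry
// Same naming scheme as the patch: listenerProcessingTime.<listener class name>.
val timer: Timer = registry.timer(MetricRegistry.name("listenerProcessingTime", "MyListener"))

val ctx = timer.time() // start the clock
try {
  // hand one event to the listener here
  println("event delivered")
} finally {
  ctx.stop() // records the elapsed duration in the timer's histogram
}
```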
+ */ + def getTimerForListenerClass(cls: Class[_ <: SparkListenerInterface]): Option[Timer] = { + synchronized { + val className = cls.getName + val maxTimed = conf.get(LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED) + perListenerClassTimers.get(className).orElse { + if (perListenerClassTimers.size == maxTimed) { + logError(s"Not measuring processing time for listener class $className because a " + + s"maximum of $maxTimed listener classes are already timed.") + None + } else { + perListenerClassTimers(className) = + metricRegistry.timer(MetricRegistry.name("listenerProcessingTime", className)) + perListenerClassTimers.get(className) + } + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index b2e9a97129f0..5e45b375ddd4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,8 +19,13 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.roaringbitmap.RoaringBitmap +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -121,34 +126,41 @@ private[spark] class CompressedMapStatus( } /** - * A [[MapStatus]] implementation that only stores the average size of non-empty blocks, + * A [[MapStatus]] implementation that stores the accurate size of huge blocks, which are larger + * than spark.shuffle.accurateBlockThreshold. It stores the average size of other non-empty blocks, * plus a bitmap for tracking which blocks are empty. * * @param loc location where the task is being executed * @param numNonEmptyBlocks the number of non-empty blocks * @param emptyBlocks a bitmap tracking which blocks are empty - * @param avgSize average size of the non-empty blocks + * @param avgSize average size of the non-empty and non-huge blocks + * @param hugeBlockSizes sizes of huge blocks by their reduceId. 
*/ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, - private[this] var avgSize: Long) + private[this] var avgSize: Long, + private var hugeBlockSizes: Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization - require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0, + require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc override def getSizeForBlock(reduceId: Int): Long = { + assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { 0 } else { - avgSize + hugeBlockSizes.get(reduceId) match { + case Some(size) => MapStatus.decompressSize(size) + case None => avgSize + } } } @@ -156,6 +168,11 @@ private[spark] class HighlyCompressedMapStatus private ( loc.writeExternal(out) emptyBlocks.writeExternal(out) out.writeLong(avgSize) + out.writeInt(hugeBlockSizes.size) + hugeBlockSizes.foreach { kv => + out.writeInt(kv._1) + out.writeByte(kv._2) + } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -163,6 +180,14 @@ private[spark] class HighlyCompressedMapStatus private ( emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() + val count = in.readInt() + val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]() + (0 until count).foreach { _ => + val block = in.readInt() + val size = in.readByte() + hugeBlockSizesArray += Tuple2(block, size) + } + hugeBlockSizes = hugeBlockSizesArray.toMap } } @@ -178,11 +203,21 @@ private[spark] object HighlyCompressedMapStatus { // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. val emptyBlocks = new RoaringBitmap() val totalNumBlocks = uncompressedSizes.length + val threshold = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD)) + .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get) + val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]() while (i < totalNumBlocks) { - var size = uncompressedSizes(i) + val size = uncompressedSizes(i) if (size > 0) { numNonEmptyBlocks += 1 - totalSize += size + // Huge blocks are not included in the calculation for average size, thus size for smaller + // blocks is more accurate. 
+ if (size < threshold) { + totalSize += size + } else { + hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i))) + } } else { emptyBlocks.add(i) } @@ -195,6 +230,7 @@ private[spark] object HighlyCompressedMapStatus { } emptyBlocks.trim() emptyBlocks.runOptimize() - new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize) + new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, + hugeBlockSizesArray.toMap) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 1181371ab425..f4b0ab10155a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -97,7 +97,7 @@ private[spark] class Pool( } override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + val sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] val sortedSchedulableQueue = schedulableQueue.asScala.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) for (schedulable <- sortedSchedulableQueue) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 08e05ae0c095..26a6a3effc9a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{InputStream, IOException} +import java.io.{EOFException, InputStream, IOException} import scala.io.Source @@ -107,6 +107,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } } catch { + case _: EOFException if maybeTruncated => case ioe: IOException => throw ioe case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index db4d9efa2270..1b44d0aee319 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -19,9 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashSet -import org.apache.spark.ShuffleDependency +import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -42,13 +41,12 @@ private[spark] class ShuffleMapStage( parents: List[Stage], firstJobId: Int, callSite: CallSite, - val shuffleDep: ShuffleDependency[_, _, _]) + val shuffleDep: ShuffleDependency[_, _, _], + mapOutputTrackerMaster: MapOutputTrackerMaster) extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { private[this] var _mapStageJobs: List[ActiveJob] = Nil - private[this] var _numAvailableOutputs: Int = 0 - /** * Partitions that either haven't yet been computed, or that were computed on an executor * that has since been lost, so should be re-computed. This variable is used by the @@ -60,13 +58,6 @@ private[spark] class ShuffleMapStage( */ val pendingPartitions = new HashSet[Int] - /** - * List of [[MapStatus]] for each partition. The index of the array is the map partition id, - * and each value in the array is the list of possible [[MapStatus]] for a partition - * (a single task might run multiple times). 
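To restate the HighlyCompressedMapStatus changes in MapStatus.scala above in isolation: blocks at or above `spark.shuffle.accurateBlockThreshold` keep a per-block size compressed to one byte, empty blocks return 0, and all remaining blocks share one average. A self-contained sketch of the lookup; the log-base-1.1 byte encoding below is an assumed stand-in for `MapStatus.compressSize`/`decompressSize`, used only for illustration:

```scala
object SizeLookupSketch {
  private val LogBase = 1.1

  // One-byte, log-scale size encoding (assumed stand-in for MapStatus.compressSize).
  def compress(size: Long): Byte =
    if (size <= 1L) size.toByte
    else math.min(255, math.ceil(math.log(size.toDouble) / math.log(LogBase)).toInt).toByte

  def decompress(b: Byte): Long =
    if (b == 0) 0L else math.pow(LogBase, (b & 0xFF).toDouble).toLong

  /** Mirrors getSizeForBlock: empty -> 0, huge -> recorded size, otherwise the average. */
  def sizeFor(
      reduceId: Int,
      emptyBlocks: Set[Int],
      hugeBlockSizes: Map[Int, Byte],
      avgSize: Long): Long = {
    if (emptyBlocks.contains(reduceId)) 0L
    else hugeBlockSizes.get(reduceId).map(decompress).getOrElse(avgSize)
  }
}
```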
- */ - private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - override def toString: String = "ShuffleMapStage " + id /** @@ -88,69 +79,18 @@ private[spark] class ShuffleMapStage( /** * Number of partitions that have shuffle outputs. * When this reaches [[numPartitions]], this map stage is ready. - * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`. */ - def numAvailableOutputs: Int = _numAvailableOutputs + def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId) /** * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs. - * This should be the same as `outputLocs.contains(Nil)`. */ - def isAvailable: Boolean = _numAvailableOutputs == numPartitions + def isAvailable: Boolean = numAvailableOutputs == numPartitions /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ override def findMissingPartitions(): Seq[Int] = { - val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) - assert(missing.size == numPartitions - _numAvailableOutputs, - s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") - missing - } - - def addOutputLoc(partition: Int, status: MapStatus): Unit = { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - _numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - _numAvailableOutputs -= 1 - } - } - - /** - * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned - * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition, - * that position is filled with null. - */ - def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = { - outputLocs.map(_.headOption.orNull) - } - - /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. 
- */ - def removeOutputsOnExecutor(execId: String): Unit = { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - _numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, _numAvailableOutputs, numPartitions, isAvailable)) - } + mapOutputTrackerMaster + .findMissingPartitions(shuffleDep.shuffleId) + .getOrElse(0 until numPartitions) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index bc2e53071668..b76e560669d5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -52,6 +52,9 @@ case class SparkListenerTaskStart(stageId: Int, stageAttemptId: Int, taskInfo: T @DeveloperApi case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerSpeculativeTaskSubmitted(stageId: Int) extends SparkListenerEvent + @DeveloperApi case class SparkListenerTaskEnd( stageId: Int, @@ -160,9 +163,9 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent /** * An internal class that describes the metadata of an event log. - * This event is not meant to be posted to listeners downstream. */ -private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent /** * Interface for creating history listeners defined in other modules like SQL, which are used to @@ -290,6 +293,11 @@ private[spark] trait SparkListenerInterface { */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit + /** + * Called when a speculative task is submitted + */ + def onSpeculativeTaskSubmitted(speculativeTask: SparkListenerSpeculativeTaskSubmitted): Unit + /** * Called when other events like SQL-specific events are posted. 
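Because `SparkListenerSpeculativeTaskSubmitted` and the `onSpeculativeTaskSubmitted` callback introduced above are public `@DeveloperApi`, an application-defined listener can observe speculation directly. A small illustrative listener (the class name and counter are ours, not part of the patch):

```scala
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.scheduler.{SparkListener, SparkListenerSpeculativeTaskSubmitted}

class SpeculationCounter extends SparkListener {
  val speculativeTasksSubmitted = new AtomicLong(0L)

  override def onSpeculativeTaskSubmitted(
      event: SparkListenerSpeculativeTaskSubmitted): Unit = {
    speculativeTasksSubmitted.incrementAndGet()
  }
}

// Registered like any other listener, e.g. sparkContext.addSparkListener(new SpeculationCounter)
// or via the spark.extraListeners configuration.
```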
*/ @@ -354,5 +362,8 @@ abstract class SparkListener extends SparkListenerInterface { override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { } + override def onSpeculativeTaskSubmitted( + speculativeTask: SparkListenerSpeculativeTaskSubmitted): Unit = { } + override def onOtherEvent(event: SparkListenerEvent): Unit = { } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 3ff363321e8c..056c0cbded43 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -71,7 +71,8 @@ private[spark] trait SparkListenerBus listener.onNodeUnblacklisted(nodeUnblacklisted) case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) - case logStart: SparkListenerLogStart => // ignore event log metadata + case speculativeTaskSubmitted: SparkListenerSpeculativeTaskSubmitted => + listener.onSpeculativeTaskSubmitted(speculativeTaskSubmitted) case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5c337b992c84..7767ef1803a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -115,26 +115,33 @@ private[spark] abstract class Task[T]( case t: Throwable => e.addSuppressed(t) } + context.markTaskCompleted(Some(e)) throw e } finally { - // Call the task completion callbacks. - context.markTaskCompleted() try { - Utils.tryLogNonFatalError { - // Release memory used by this thread for unrolling blocks - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP) - // Notify any tasks waiting for execution memory to be freed to wake up and try to - // acquire memory again. This makes impossible the scenario where a task sleeps forever - // because there are no other tasks left to notify it. Since this is safe to do but may - // not be strictly necessary, we should revisit whether we can remove this in the future. - val memoryManager = SparkEnv.get.memoryManager - memoryManager.synchronized { memoryManager.notifyAll() } - } + // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second + // one is no-op. + context.markTaskCompleted(None) } finally { - // Though we unset the ThreadLocal here, the context member variable itself is still queried - // directly in the TaskRunner to check for FetchFailedExceptions. - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask( + MemoryMode.OFF_HEAP) + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the + // future. 
+ val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } + } + } finally { + // Though we unset the ThreadLocal here, the context member variable itself is still + // queried directly in the TaskRunner to check for FetchFailedExceptions. + TaskContext.unset() + } } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 366b92c5f2ad..836769e1723d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -60,7 +60,7 @@ private[spark] class DirectTaskResult[T]( val numUpdates = in.readInt if (numUpdates == 0) { - accumUpdates = Seq() + accumUpdates = Seq.empty } else { val _accumUpdates = new ArrayBuffer[AccumulatorV2[_, _]] for (i <- 0 until numUpdates) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 3de7d1f7de22..90644fea23ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -89,6 +89,11 @@ private[spark] trait TaskScheduler { */ def executorLost(executorId: String, reason: ExecutorLossReason): Unit + /** + * Process a removed worker + */ + def workerRemoved(workerId: String, host: String, message: String): Unit + /** * Get an application's attempt ID associated with the job. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1b6bc9139f9c..0c11806b3981 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -32,7 +32,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality -import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} @@ -51,29 +50,21 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ -private[spark] class TaskSchedulerImpl private[scheduler]( +private[spark] class TaskSchedulerImpl( val sc: SparkContext, val maxTaskFailures: Int, - private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker], isLocal: Boolean = false) extends TaskScheduler with Logging { import TaskSchedulerImpl._ def this(sc: SparkContext) = { - this( - sc, - sc.conf.get(config.MAX_TASK_FAILURES), - TaskSchedulerImpl.maybeCreateBlacklistTracker(sc)) + this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) } - def this(sc: SparkContext, maxTaskFailures: Int, isLocal: Boolean) = { - this( - sc, - maxTaskFailures, - TaskSchedulerImpl.maybeCreateBlacklistTracker(sc), - isLocal = isLocal) - } + // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient, + // because ExecutorAllocationClient is created after this TaskSchedulerImpl. 
+ private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc) val conf = sc.conf @@ -129,7 +120,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( var backend: SchedulerBackend = null - val mapOutputTracker = SparkEnv.get.mapOutputTracker + val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO @@ -240,8 +231,8 @@ private[spark] class TaskSchedulerImpl private[scheduler]( // 2. The task set manager has been created but no tasks has been scheduled. In this case, // simply abort the stage. tsm.runningTasksSet.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread, reason = "stage cancelled") + taskIdToExecutorId.get(tid).foreach(execId => + backend.killTask(tid, execId, interruptThread, reason = "Stage cancelled")) } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) @@ -353,7 +344,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( val shuffledOffers = shuffleOffers(filteredOffers) // Build a list of tasks to assign to each worker. - val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) val availableCpus = shuffledOffers.map(o => o.cores).toArray val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { @@ -569,6 +560,11 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } } + override def workerRemoved(workerId: String, host: String, message: String): Unit = { + logInfo(s"Handle removed worker $workerId: $message") + dagScheduler.workerRemoved(workerId, host, message) + } + private def logExecutorLoss( executorId: String, hostPort: String, diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala index e815b7e0cf6c..233781f3d971 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -61,6 +61,16 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, private val blacklistedExecs = new HashSet[String]() private val blacklistedNodes = new HashSet[String]() + private var latestFailureReason: String = null + + /** + * Get the most recent failure reason of this TaskSet. + * @return + */ + def getLatestFailureReason: String = { + latestFailureReason + } + /** * Return true if this executor is blacklisted for the given task. 
This does *not* * need to return true if the executor is blacklisted for the entire stage, or blacklisted @@ -94,7 +104,9 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, private[scheduler] def updateBlacklistForFailedTask( host: String, exec: String, - index: Int): Unit = { + index: Int, + failureReason: String): Unit = { + latestFailureReason = failureReason val execFailures = execToFailures.getOrElseUpdate(exec, new ExecutorFailuresInTaskSet(host)) execFailures.updateWithFailure(index, clock.getTimeMillis()) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a41b059fa7de..bb867416a4fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -26,9 +26,9 @@ import scala.math.max import scala.util.control.NonFatal import org.apache.spark._ +import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode._ -import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap @@ -56,6 +56,10 @@ private[spark] class TaskSetManager( private val conf = sched.sc.conf + // SPARK-21563 make a copy of the jars/files so they are consistent across the TaskSet + private val addedJars = HashMap[String, Long](sched.sc.addedJars.toSeq: _*) + private val addedFiles = HashMap[String, Long](sched.sc.addedFiles.toSeq: _*) + // Quantile of tasks at which to start speculation val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) @@ -198,7 +202,7 @@ private[spark] class TaskSetManager( private[scheduler] var emittedTaskSizeWarning = false /** Add a task to all the pending-task lists that it should be on. */ - private def addPendingTask(index: Int) { + private[spark] def addPendingTask(index: Int) { for (loc <- tasks(index).preferredLocations) { loc match { case e: ExecutorCacheTaskLocation => @@ -502,8 +506,8 @@ private[spark] class TaskSetManager( execId, taskName, index, - sched.sc.addedFiles, - sched.sc.addedJars, + addedFiles, + addedJars, task.localProperties, serializedTask) } @@ -666,9 +670,14 @@ private[spark] class TaskSetManager( } if (blacklistedEverywhere) { val partition = tasks(indexInTaskSet).partitionId - abort(s"Aborting $taskSet because task $indexInTaskSet (partition $partition) " + - s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior " + - s"can be configured via spark.blacklist.*.") + abort(s""" + |Aborting $taskSet because task $indexInTaskSet (partition $partition) + |cannot run anywhere due to node and executor blacklist. + |Most recent failure: + |${taskSetBlacklist.getLatestFailureReason} + | + |Blacklisting behavior can be configured via spark.blacklist.*. 
+ |""".stripMargin) } } } @@ -774,6 +783,12 @@ private[spark] class TaskSetManager( tasksSuccessful += 1 } isZombie = true + + if (fetchFailed.bmAddress != null) { + blacklistTracker.foreach(_.updateBlacklistForFetchFailure( + fetchFailed.bmAddress.host, fetchFailed.bmAddress.executorId)) + } + None case ef: ExceptionFailure => @@ -826,19 +841,10 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) - if (successful(index)) { - logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" + - s" be re-executed (either because the task failed with a shuffle data fetch failure," + - s" so the previous stage needs to be re-run, or because a different copy of the task" + - s" has already succeeded).") - } else { - addPendingTask(index) - } - if (!isZombie && reason.countTowardsTaskFailures) { - taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( - info.host, info.executorId, index)) assert (null != failureReason) + taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( + info.host, info.executorId, index, failureReason)) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { logError("Task %d in stage %s failed %d times; aborting job".format( @@ -848,6 +854,16 @@ private[spark] class TaskSetManager( return } } + + if (successful(index)) { + logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" + + s" be re-executed (either because the task failed with a shuffle data fetch failure," + + s" so the previous stage needs to be re-run, or because a different copy of the task" + + s" has already succeeded).") + } else { + addPendingTask(index) + } + maybeFinishTaskSet() } @@ -884,7 +900,7 @@ private[spark] class TaskSetManager( override def removeSchedulable(schedulable: Schedulable) {} override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]() + val sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]() sortedTaskSetQueue += this sortedTaskSetQueue } @@ -941,7 +957,7 @@ private[spark] class TaskSetManager( if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { val time = clock.getTimeMillis() - var medianDuration = successfulTaskDurations.median + val medianDuration = successfulTaskDurations.median val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation) // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. 
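For reference, the speculation check in the surrounding context: a task becomes a speculation candidate once its elapsed time exceeds max(multiplier × median successful duration, a minimum floor), and with this patch the task set manager also notifies the DAGScheduler via the `speculativeTaskSubmitted` call added just below. A sketch of that arithmetic; the 1.5 multiplier is the `spark.speculation.multiplier` default shown above, while the 100 ms floor is an assumption for illustration:

```scala
// Sketch of the threshold used by checkSpeculatableTasks.
def speculationThresholdMs(
    successfulDurationsMs: Seq[Long],
    multiplier: Double = 1.5,           // spark.speculation.multiplier default
    minTimeToSpeculationMs: Long = 100  // assumed floor, for illustration
  ): Double = {
  require(successfulDurationsMs.nonEmpty, "need at least one finished task")
  val sorted = successfulDurationsMs.sorted
  val median = sorted(sorted.length / 2).toDouble
  math.max(multiplier * median, minTimeToSpeculationMs.toDouble)
}

// e.g. speculationThresholdMs(Seq(200L, 220L, 5000L)) == 330.0: a straggler still running
// past 330 ms may be marked speculatable.
```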
@@ -955,6 +971,7 @@ private[spark] class TaskSetManager( "Marking task %d in stage %s (on %s) as speculatable because it ran more than %.0f ms" .format(index, taskSet.id, info.host, threshold)) speculatableTasks += index + sched.dagScheduler.speculativeTaskSubmitted(tasks(index)) foundTasks = true } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 6b49bd699a13..5d65731dfc30 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -32,7 +32,8 @@ private[spark] object CoarseGrainedClusterMessages { case class SparkAppConfig( sparkProperties: Seq[(String, String)], - ioEncryptionKey: Option[Array[Byte]]) + ioEncryptionKey: Option[Array[Byte]], + hadoopDelegationCreds: Option[Array[Byte]]) extends CoarseGrainedClusterMessage case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage @@ -85,6 +86,9 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: ExecutorLossReason) extends CoarseGrainedClusterMessage + case class RemoveWorker(workerId: String, host: String, message: String) + extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage // Exchanged between the driver and the AM in Yarn client mode diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index dc82bb770472..424e43b25c77 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -23,9 +23,12 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Future -import scala.concurrent.duration.Duration + +import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -43,8 +46,8 @@ import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils, Utils} */ private[spark] class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) - extends ExecutorAllocationClient with SchedulerBackend with Logging -{ + extends ExecutorAllocationClient with SchedulerBackend with Logging { + // Use an atomic variable to track total number of cores in the cluster for simplicity and speed protected val totalCoreCount = new AtomicInteger(0) // Total number of executors that are currently registered @@ -96,6 +99,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // The num of current max ExecutorId used to re-register appMaster @volatile protected var currentExecutorIdCounter = 0 + // hadoop token manager used by some sub-classes (e.g. Mesos) + def hadoopDelegationTokenManager: Option[HadoopDelegationTokenManager] = None + + // Hadoop delegation tokens to be sent to the executors. 
+ val hadoopDelegationCreds: Option[Array[Byte]] = getHadoopDelegationCreds() + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -219,9 +228,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeExecutor(executorId, reason) context.reply(true) + case RemoveWorker(workerId, host, message) => + removeWorker(workerId, host, message) + context.reply(true) + case RetrieveSparkAppConfig => - val reply = SparkAppConfig(sparkProperties, - SparkEnv.get.securityManager.getIOEncryptionKey()) + val reply = SparkAppConfig( + sparkProperties, + SparkEnv.get.securityManager.getIOEncryptionKey(), + hadoopDelegationCreds) context.reply(reply) } @@ -231,8 +246,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized { // Filter out executors under killing val activeExecutors = executorDataMap.filterKeys(executorIsAlive) - val workOffers = activeExecutors.map { case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + val workOffers = activeExecutors.map { + case (id, executorData) => + new WorkerOffer(id, executorData.executorHost, executorData.freeCores) }.toIndexedSeq scheduler.resourceOffers(workOffers) } @@ -331,6 +347,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + // Remove a lost worker from the cluster + private def removeWorker(workerId: String, host: String, message: String): Unit = { + logDebug(s"Asked to remove worker $workerId with reason $message") + scheduler.workerRemoved(workerId, host, message) + } + /** * Stop making resource offers for the given executor. The executor is marked as lost with * the loss reason still pending. @@ -416,11 +438,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * be called in the yarn-client mode when AM re-registers after a failure. * */ protected def reset(): Unit = { - val executors = synchronized { + val executors: Set[String] = synchronized { requestedTotalExecutors = 0 numPendingExecutors = 0 executorsPendingToRemove.clear() - Set() ++ executorDataMap.keys + executorDataMap.keys.toSet } // Remove all the lingering executors that should be removed but not yet. The reason might be @@ -449,9 +471,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp */ protected def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { // Only log the failure since we don't care about the result. 
- driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).onFailure { case t => - logError(t.getMessage, t) - }(ThreadUtils.sameThread) + driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).failed.foreach(t => + logError(t.getMessage, t))(ThreadUtils.sameThread) + } + + protected def removeWorker(workerId: String, host: String, message: String): Unit = { + driverEndpoint.ask[Boolean](RemoveWorker(workerId, host, message)).failed.foreach(t => + logError(t.getMessage, t))(ThreadUtils.sameThread) } def sufficientResourcesRegistered(): Boolean = true @@ -659,6 +685,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp driverEndpoint.send(KillExecutorsOnHost(host)) true } + + protected def getHadoopDelegationCreds(): Option[Array[Byte]] = { + if (UserGroupInformation.isSecurityEnabled && hadoopDelegationTokenManager.isDefined) { + hadoopDelegationTokenManager.map { manager => + val creds = UserGroupInformation.getCurrentUser.getCredentials + val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + manager.obtainDelegationTokens(hadoopConf, creds) + SparkHadoopUtil.get.serialize(creds) + } + } else { + None + } + } } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 0529fe9eed4d..a4e2a7434128 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -58,7 +58,13 @@ private[spark] class StandaloneSchedulerBackend( override def start() { super.start() - launcherBackend.connect() + + // SPARK-21159. The scheduler backend should only try to connect to the launcher when in client + // mode. In cluster mode, the code that submits the application to the Master needs to connect + // to the launcher instead. 
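On the `getHadoopDelegationCreds` helper added to CoarseGrainedSchedulerBackend above: tokens are gathered only when Kerberos is enabled and a token manager exists, and the resulting credentials are shipped to executors as bytes inside `SparkAppConfig`. A hedged sketch of the guard-and-serialize half using only public Hadoop APIs; how tokens get into the `Credentials` is Spark-internal and is represented here by a caller-supplied function:

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

import org.apache.hadoop.security.{Credentials, UserGroupInformation}

// `fillTokens` stands in for the token manager's obtainDelegationTokens call and is an
// assumption of this example, not Spark's API.
def serializedTokensIfSecure(fillTokens: Credentials => Unit): Option[Array[Byte]] = {
  if (!UserGroupInformation.isSecurityEnabled) {
    None
  } else {
    val creds = UserGroupInformation.getCurrentUser.getCredentials
    fillTokens(creds)                      // e.g. obtain HDFS/Hive delegation tokens
    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    creds.writeTokenStorageToStream(out)   // standard Hadoop credentials serialization
    out.close()
    Some(buffer.toByteArray)
  }
}
```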
+ if (sc.deployMode == "client") { + launcherBackend.connect() + } // The endpoint for executors to talk to us val driverUrl = RpcEndpointAddress( @@ -161,6 +167,11 @@ private[spark] class StandaloneSchedulerBackend( removeExecutor(fullId.split("/")(1), reason) } + override def workerRemoved(workerId: String, host: String, message: String): Unit = { + logInfo("Worker %s removed: %s".format(workerId, message)) + removeWorker(workerId, host, message) + } + override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio } diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 78dabb42ac9d..00621976b77f 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.security -import java.io.{EOFException, InputStream, OutputStream} +import java.io.{InputStream, OutputStream} import java.nio.ByteBuffer import java.nio.channels.{ReadableByteChannel, WritableByteChannel} import java.util.Properties diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index e15166d11c24..58483c9577d2 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -175,6 +175,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(None.getClass) kryo.register(Nil.getClass) kryo.register(Utils.classForName("scala.collection.immutable.$colon$colon")) + kryo.register(Utils.classForName("scala.collection.immutable.Map$EmptyMap$")) kryo.register(classOf[ArrayBuffer[Any]]) kryo.setClassLoader(classLoader) @@ -500,8 +501,8 @@ private class JavaIterableWrapperSerializer private object JavaIterableWrapperSerializer extends Logging { // The class returned by JavaConverters.asJava // (scala.collection.convert.Wrappers$IterableWrapper). - val wrapperClass = - scala.collection.convert.WrapAsJava.asJavaIterable(Seq(1)).getClass + import scala.collection.JavaConverters._ + val wrapperClass = Seq(1).asJava.getClass // Get the underlying method so we can use it to get the Scala collection for serialization. 
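On the new Kryo registration in KryoSerializer above: `scala.collection.immutable.Map$EmptyMap$` is the private singleton class behind `Map.empty`, which cannot be named directly in Scala source, hence the `Utils.classForName` lookup in the patch. A tiny sketch of an equivalent registration outside Spark, using the runtime class and the plain Kryo API:

```scala
import com.esotericsoftware.kryo.Kryo

val kryo = new Kryo()
kryo.setRegistrationRequired(true)

// Map.empty is backed by the private singleton scala.collection.immutable.Map$EmptyMap$;
// registering its runtime class avoids "class is not registered" failures for empty maps.
kryo.register(Map.empty[String, Int].getClass)
```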
private val underlyingMethodOpt = { diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index bb7ed8709ba8..311383e7ea2b 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -41,6 +41,10 @@ private[spark] class SerializerManager( private[this] val kryoSerializer = new KryoSerializer(conf) + def setDefaultClassLoader(classLoader: ClassLoader): Unit = { + kryoSerializer.setDefaultClassLoader(classLoader) + } + private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]] private[this] val primitiveAndPrimitiveArrayClassTags: Set[ClassTag[_]] = { val primitiveClassTags = Set[ClassTag[_]]( diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index ba3e0e395e95..c8d146030093 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle import org.apache.spark._ -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator @@ -51,6 +51,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 15540485170d..94a3a78e9416 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle import java.io._ +import java.nio.file.Files import com.google.common.io.ByteStreams @@ -141,7 +142,8 @@ private[spark] class IndexShuffleBlockResolver( val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) try { - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + val out = new DataOutputStream( + new BufferedOutputStream(Files.newOutputStream(indexTmp.toPath))) Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. 
var offset = 0L @@ -196,7 +198,7 @@ private[spark] class IndexShuffleBlockResolver( // find out the consolidated file, then the offset within that from our index val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) - val in = new DataInputStream(new FileInputStream(indexFile)) + val in = new DataInputStream(Files.newInputStream(indexFile.toPath)) try { ByteStreams.skipFully(in, blockId.reduceId * 8) val offset = in.readLong() diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala index 01f2a18122e6..eb5cc1b9a3bd 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.status.api.v1 import javax.ws.rs.{GET, Produces} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 1818935392eb..4a4ed954d689 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -47,6 +47,7 @@ private[v1] class AllStagesResource(ui: SparkUI) { listener.stageIdToData.get((stageInfo.stageId, stageInfo.attemptId)) } } yield { + stageUiData.lastUpdateTime = ui.lastUpdateTime AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, includeDetails = false) } } @@ -69,7 +70,8 @@ private[v1] object AllStagesResource { } val taskData = if (includeDetails) { - Some(stageUiData.taskData.map { case (k, v) => k -> convertTaskData(v) } ) + Some(stageUiData.taskData.map { case (k, v) => + k -> convertTaskData(v, stageUiData.lastUpdateTime) }) } else { None } @@ -136,13 +138,13 @@ private[v1] object AllStagesResource { } } - def convertTaskData(uiData: TaskUIData): TaskData = { + def convertTaskData(uiData: TaskUIData, lastUpdateTime: Option[Long]): TaskData = { new TaskData( taskId = uiData.taskInfo.taskId, index = uiData.taskInfo.index, attempt = uiData.taskInfo.attemptNumber, launchTime = new Date(uiData.taskInfo.launchTime), - duration = uiData.taskDuration, + duration = uiData.taskDuration(lastUpdateTime), executorId = uiData.taskInfo.executorId, host = uiData.taskInfo.host, status = uiData.taskInfo.status, @@ -200,6 +202,7 @@ private[v1] object AllStagesResource { readBytes = submetricQuantiles(_.totalBytesRead), readRecords = submetricQuantiles(_.recordsRead), remoteBytesRead = submetricQuantiles(_.remoteBytesRead), + remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), localBlocksFetched = submetricQuantiles(_.localBlocksFetched), totalBlocksFetched = submetricQuantiles(_.totalBlocksFetched), @@ -281,6 +284,7 @@ private[v1] object AllStagesResource { localBlocksFetched = internal.localBlocksFetched, fetchWaitTime = internal.fetchWaitTime, remoteBytesRead = internal.remoteBytesRead, + remoteBytesReadToDisk = internal.remoteBytesReadToDisk, localBytesRead = internal.localBytesRead, recordsRead = internal.recordsRead ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index a0239266d875..f039744e7f67 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -90,7 +90,8 @@ private[spark] object ApplicationsListResource { }, lastUpdated = new Date(internalAttemptInfo.lastUpdated), sparkUser = internalAttemptInfo.sparkUser, - completed = internalAttemptInfo.completed + completed = internalAttemptInfo.completed, + appSparkVersion = internalAttemptInfo.appSparkVersion ) } ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala index ab5388159418..2f3b5e984002 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software 
Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.status.api.v1 import javax.ws.rs.{GET, Produces} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala index 3e6d2942d0fb..f15073bccced 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -35,6 +35,7 @@ private[v1] class OneStageResource(ui: SparkUI) { def stageData(@PathParam("stageId") stageId: Int): Seq[StageData] = { withStage(stageId) { stageAttempts => stageAttempts.map { stage => + stage.ui.lastUpdateTime = ui.lastUpdateTime AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui, includeDetails = true) } @@ -47,6 +48,7 @@ private[v1] class OneStageResource(ui: SparkUI) { @PathParam("stageId") stageId: Int, @PathParam("stageAttemptId") stageAttemptId: Int): StageData = { withStageAttempt(stageId, stageAttemptId) { stage => + stage.ui.lastUpdateTime = ui.lastUpdateTime AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui, includeDetails = true) } @@ -81,7 +83,8 @@ private[v1] class OneStageResource(ui: SparkUI) { @DefaultValue("20") @QueryParam("length") length: Int, @DefaultValue("ID") @QueryParam("sortBy") sortBy: TaskSorting): Seq[TaskData] = { withStageAttempt(stageId, stageAttemptId) { stage => - val tasks = stage.ui.taskData.values.map{AllStagesResource.convertTaskData}.toIndexedSeq + val tasks = stage.ui.taskData.values + .map{ AllStagesResource.convertTaskData(_, ui.lastUpdateTime)}.toIndexedSeq .sorted(OneStageResource.ordering(sortBy)) tasks.slice(offset, offset + length) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 56d8e51732ff..31659b25db31 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ 
b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -20,6 +20,8 @@ import java.util.Date import scala.collection.Map +import com.fasterxml.jackson.annotation.JsonIgnoreProperties + import org.apache.spark.JobExecutionStatus class ApplicationInfo private[spark]( @@ -31,6 +33,9 @@ class ApplicationInfo private[spark]( val memoryPerExecutorMB: Option[Int], val attempts: Seq[ApplicationAttemptInfo]) +@JsonIgnoreProperties( + value = Array("startTimeEpoch", "endTimeEpoch", "lastUpdatedEpoch"), + allowGetters = true) class ApplicationAttemptInfo private[spark]( val attemptId: Option[String], val startTime: Date, @@ -38,10 +43,15 @@ class ApplicationAttemptInfo private[spark]( val lastUpdated: Date, val duration: Long, val sparkUser: String, - val completed: Boolean = false) { - def getStartTimeEpoch: Long = startTime.getTime - def getEndTimeEpoch: Long = endTime.getTime - def getLastUpdatedEpoch: Long = lastUpdated.getTime + val completed: Boolean = false, + val appSparkVersion: String) { + + def getStartTimeEpoch: Long = startTime.getTime + + def getEndTimeEpoch: Long = endTime.getTime + + def getLastUpdatedEpoch: Long = lastUpdated.getTime + } class ExecutorStageSummary private[spark]( @@ -207,6 +217,7 @@ class ShuffleReadMetrics private[spark]( val localBlocksFetched: Long, val fetchWaitTime: Long, val remoteBytesRead: Long, + val remoteBytesReadToDisk: Long, val localBytesRead: Long, val recordsRead: Long) @@ -248,6 +259,7 @@ class ShuffleReadMetricDistributions private[spark]( val localBlocksFetched: IndexedSeq[Double], val fetchWaitTime: IndexedSeq[Double], val remoteBytesRead: IndexedSeq[Double], + val remoteBytesReadToDisk: IndexedSeq[Double], val totalBlocksFetched: IndexedSeq[Double]) class ShuffleWriteMetricDistributions private[spark]( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 3db59837fbeb..219a0e799cc7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -281,22 +281,27 @@ private[storage] class BlockInfoManager extends Logging { /** * Release a lock on the given block. + * In case a TaskContext is not propagated properly to all child threads for the task, we fail to + * get the TID from TaskContext, so we have to explicitly pass the TID value to release the lock. + * + * See SPARK-18406 for more discussion of this issue. 
*/ - def unlock(blockId: BlockId): Unit = synchronized { - logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId") + def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized { + val taskId = taskAttemptId.getOrElse(currentTaskAttemptId) + logTrace(s"Task $taskId releasing lock for $blockId") val info = get(blockId).getOrElse { throw new IllegalStateException(s"Block $blockId not found") } if (info.writerTask != BlockInfo.NO_WRITER) { info.writerTask = BlockInfo.NO_WRITER - writeLocksByTask.removeBinding(currentTaskAttemptId, blockId) + writeLocksByTask.removeBinding(taskId, blockId) } else { assert(info.readerCount > 0, s"Block $blockId is not locked for reading") info.readerCount -= 1 - val countsForTask = readLocksByTask(currentTaskAttemptId) + val countsForTask = readLocksByTask(taskId) val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1 assert(newPinCountForTask >= 0, - s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it") + s"Task $taskId release lock on block $blockId more times than it acquired it") } notifyAll() } @@ -336,15 +341,11 @@ private[storage] class BlockInfoManager extends Logging { * * @return the ids of blocks whose pins were released */ - def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = { + def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = synchronized { val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() - val readLocks = synchronized { - readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) - } - val writeLocks = synchronized { - writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) - } + val readLocks = readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) + val writeLocks = writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) for (blockId <- writeLocks) { infos.get(blockId).foreach { info => @@ -353,21 +354,19 @@ private[storage] class BlockInfoManager extends Logging { } blocksWithReleasedLocks += blockId } + readLocks.entrySet().iterator().asScala.foreach { entry => val blockId = entry.getElement val lockCount = entry.getCount blocksWithReleasedLocks += blockId - synchronized { - get(blockId).foreach { info => - info.readerCount -= lockCount - assert(info.readerCount >= 0) - } + get(blockId).foreach { info => + info.readerCount -= lockCount + assert(info.readerCount >= 0) } } - synchronized { - notifyAll() - } + notifyAll() + blocksWithReleasedLocks } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3219969bcd06..a98083df5bd8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -23,25 +23,25 @@ import java.nio.channels.Channels import scala.collection.mutable import scala.collection.mutable.HashMap -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import com.google.common.io.ByteStreams +import com.codahale.metrics.{MetricRegistry, MetricSet} import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import 
org.apache.spark.memory.{MemoryManager, MemoryMode} +import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ @@ -177,7 +177,8 @@ private[spark] class BlockManager( // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) + new ExternalShuffleClient(transConf, securityManager, + securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) } else { blockTransferService } @@ -250,6 +251,16 @@ private[spark] class BlockManager( logInfo(s"Initialized BlockManager: $blockManagerId") } + def shuffleMetricsSource: Source = { + import BlockManager._ + + if (externalShuffleServiceEnabled) { + new ShuffleMetricsSource("ExternalShuffle", shuffleClient.shuffleMetrics()) + } else { + new ShuffleMetricsSource("NettyBlockTransfer", shuffleClient.shuffleMetrics()) + } + } + private def registerWithExternalShuffleServer() { logInfo("Registering executor with local external shuffle service.") val shuffleConfig = new ExecutorShuffleInfo( @@ -257,7 +268,7 @@ private[spark] class BlockManager( diskBlockManager.subDirsPerLocalDir, shuffleManager.getClass.getName) - val MAX_ATTEMPTS = 3 + val MAX_ATTEMPTS = conf.get(config.SHUFFLE_REGISTRATION_MAX_ATTEMPTS) val SLEEP_TIME_SECS = 5 for (i <- 1 to MAX_ATTEMPTS) { @@ -337,7 +348,7 @@ private[spark] class BlockManager( val task = asyncReregisterTask if (task != null) { try { - Await.ready(task, Duration.Inf) + ThreadUtils.awaitReady(task, Duration.Inf) } catch { case NonFatal(t) => throw new Exception("Error occurred while waiting for async. reregistration", t) @@ -504,6 +515,7 @@ private[spark] class BlockManager( case Some(info) => val level = info.level logDebug(s"Level for block $blockId is $level") + val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId()) if (level.useMemory && memoryStore.contains(blockId)) { val iter: Iterator[Any] = if (level.deserialized) { memoryStore.getValues(blockId).get @@ -511,7 +523,12 @@ private[spark] class BlockManager( serializerManager.dataDeserializeStream( blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag) } - val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) + // We need to capture the current taskId in case the iterator completion is triggered + // from a different thread which does not have TaskContext set; see SPARK-18406 for + // discussion. 
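The comment above is the heart of the SPARK-18406 fix: the task attempt id is read from the `TaskContext` on the thread that acquired the lock, and the completion callback closes over that value instead of re-reading a thread-local that may be unset in a child thread. A reduced sketch of the pattern; `releaseLock` here is a stand-in for `BlockManager.releaseLock`:

```scala
import org.apache.spark.TaskContext

// Capture now, release later: the returned callback is safe to run on any thread.
def lockReleaseCallback(
    blockId: String,
    releaseLock: (String, Option[Long]) => Unit): () => Unit = {
  val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId()) // captured eagerly
  () => releaseLock(blockId, taskAttemptId)
}
```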
+ val ci = CompletionIterator[Any, Iterator[Any]](iter, { + releaseLock(blockId, taskAttemptId) + }) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) } else if (level.useDisk && diskStore.contains(blockId)) { val diskData = diskStore.getBytes(blockId) @@ -528,8 +545,9 @@ private[spark] class BlockManager( serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) } } - val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, - releaseLockAndDispose(blockId, diskData)) + val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, { + releaseLockAndDispose(blockId, diskData, taskAttemptId) + }) Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } else { handleLocalReadFailure(blockId) @@ -612,12 +630,19 @@ private[spark] class BlockManager( /** * Return a list of locations for the given block, prioritizing the local machine since - * multiple block managers can share the same host. + * multiple block managers can share the same host, followed by hosts on the same rack. */ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { val locs = Random.shuffle(master.getLocations(blockId)) val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } - preferredLocs ++ otherLocs + blockManagerId.topologyInfo match { + case None => preferredLocs ++ otherLocs + case Some(_) => + val (sameRackLocs, differentRackLocs) = otherLocs.partition { + loc => blockManagerId.topologyInfo == loc.topologyInfo + } + preferredLocs ++ sameRackLocs ++ differentRackLocs + } } /** @@ -707,10 +732,13 @@ private[spark] class BlockManager( } /** - * Release a lock on the given block. + * Release a lock on the given block with explicit TID. + * The param `taskAttemptId` should be passed in case we can't get the correct TID from + * TaskContext, for example, the input iterator of a cached RDD iterates to the end in a child + * thread. */ - def releaseLock(blockId: BlockId): Unit = { - blockInfoManager.unlock(blockId) + def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = { + blockInfoManager.unlock(blockId, taskAttemptId) } /** @@ -912,7 +940,7 @@ private[spark] class BlockManager( if (level.replication > 1) { // Wait for asynchronous replication to finish try { - Await.ready(replicationFuture, Duration.Inf) + ThreadUtils.awaitReady(replicationFuture, Duration.Inf) } catch { case NonFatal(t) => throw new Exception("Error occurred while waiting for replication to finish", t) @@ -973,11 +1001,16 @@ private[spark] class BlockManager( logWarning(s"Putting block $blockId failed") } res + } catch { + // Since removeBlockInternal may throw exception, + // we should print exception first to show root cause. + case NonFatal(e) => + logWarning(s"Putting block $blockId failed due to exception $e.") + throw e } finally { // This cleanup is performed in a finally block rather than a `catch` to avoid having to // catch and properly re-throw InterruptedException. if (exceptionWasThrown) { - logWarning(s"Putting block $blockId failed due to an exception") // If an exception was thrown then it's possible that the code in `putBody` has already // notified the master about the availability of this block, so we need to send an update // to remove this block location. 
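// The reworked getLocations above orders candidate block locations by locality: same host first,
// then (when topology info is available) hosts on the same rack, then everything else. A small
// standalone sketch of that ordering; Loc and its fields are illustrative stand-ins for
// BlockManagerId and its topologyInfo, and the hosts and racks below are made up:
object LocalityOrderSketch {
  final case class Loc(host: String, rack: Option[String])

  def order(self: Loc, candidates: Seq[Loc]): Seq[Loc] = {
    val (local, others) = candidates.partition(_.host == self.host)
    self.rack match {
      case None => local ++ others                     // no topology info: local-first only
      case Some(_) =>
        val (sameRack, otherRacks) = others.partition(_.rack == self.rack)
        local ++ sameRack ++ otherRacks                 // local, then same rack, then the rest
    }
  }

  def main(args: Array[String]): Unit = {
    val self = Loc("host-a", Some("rack-1"))
    val candidates = Seq(
      Loc("host-z", Some("rack-2")),
      Loc("host-b", Some("rack-1")),
      Loc("host-a", Some("rack-1")))
    println(order(self, candidates))                   // host-a first, then host-b, then host-z
  }
}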
@@ -1260,11 +1293,11 @@ private[spark] class BlockManager( val numPeersToReplicateTo = level.replication - 1 val startTime = System.nanoTime - var peersReplicatedTo = mutable.HashSet.empty ++ existingReplicas - var peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId] + val peersReplicatedTo = mutable.HashSet.empty ++ existingReplicas + val peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId] var numFailures = 0 - val initialPeers = getPeers(false).filterNot(existingReplicas.contains(_)) + val initialPeers = getPeers(false).filterNot(existingReplicas.contains) var peersForReplication = blockReplicationPolicy.prioritize( blockManagerId, @@ -1458,13 +1491,18 @@ private[spark] class BlockManager( } private def addUpdatedBlockStatusToTaskMetrics(blockId: BlockId, status: BlockStatus): Unit = { - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> status) + if (conf.get(config.TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES)) { + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(blockId -> status) + } } } - def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = { - blockInfoManager.unlock(blockId) + def releaseLockAndDispose( + blockId: BlockId, + data: BlockData, + taskAttemptId: Option[Long] = None): Unit = { + releaseLock(blockId, taskAttemptId) data.dispose() } @@ -1506,4 +1544,12 @@ private[spark] object BlockManager { } blockManagers.toMap } + + private class ShuffleMetricsSource( + override val sourceName: String, + metricSet: MetricSet) extends Source { + + override val metricRegistry = new MetricRegistry + metricRegistry.registerAll(metricSet) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index c37a3604d28f..2c3da0ee85e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -46,7 +46,7 @@ class BlockManagerId private ( def executorId: String = executorId_ if (null != host_) { - Utils.checkHost(host_, "Expected hostname") + Utils.checkHost(host_) assert (port_ > 0) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index 1ea0d378cbe8..3d3806126676 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -22,7 +22,6 @@ import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.util.io.ChunkedByteBuffer /** * This [[ManagedBuffer]] wraps a [[BlockData]] instance retrieved from the [[BlockManager]] diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index ea5d8423a588..8b1dc0ba6356 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -118,10 +118,9 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given RDD. 
*/ def removeRdd(rddId: Int, blocking: Boolean) { val future = driverEndpoint.askSync[Future[Seq[Int]]](RemoveRdd(rddId)) - future.onFailure { - case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) - }(ThreadUtils.sameThread) + future.failed.foreach(e => + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) + )(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) } @@ -130,10 +129,9 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { val future = driverEndpoint.askSync[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) - future.onFailure { - case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) - }(ThreadUtils.sameThread) + future.failed.foreach(e => + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) + )(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) } @@ -143,11 +141,10 @@ class BlockManagerMaster( def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { val future = driverEndpoint.askSync[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) - future.onFailure { - case e: Exception => - logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}", e) - }(ThreadUtils.sameThread) + future.failed.foreach(e => + logWarning(s"Failed to remove broadcast $broadcastId" + + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}", e) + )(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 6f85b9e4d6c7..df0a5f5e229f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,8 +19,8 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.util.Random diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 1aaa42459df6..742cf4fe393f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -85,13 +85,13 @@ class BlockManagerSlaveEndpoint( logDebug(actionMessage) body } - future.onSuccess { case response => - logDebug("Done " + actionMessage + ", response is " + response) + future.foreach { response => + logDebug(s"Done $actionMessage, response is $response") context.reply(response) - logDebug("Sent response: " + response + " to " + context.senderAddress) + logDebug(s"Sent response: $response to ${context.senderAddress}") } - future.onFailure { case t: Throwable => - logError("Error in " + actionMessage, t) + future.failed.foreach { t => + logError(s"Error in $actionMessage", t) context.sendFailure(t) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index eb3ff926372a..a024c83d8d8b 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -95,6 +95,7 @@ private[spark] class DiskBlockObjectWriter( /** * Keep track of number of records written and also use this to periodically * output bytes written since the latter is expensive to do for each record. + * And we reset it after every commitAndGet called. */ private var numRecordsWritten = 0 @@ -185,6 +186,7 @@ private[spark] class DiskBlockObjectWriter( // In certain compression codecs, more bytes are written after streams are closed writeMetrics.incBytesWritten(committedPosition - reportedPosition) reportedPosition = committedPosition + numRecordsWritten = 0 fileSegment } else { new FileSegment(file, committedPosition, 0) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index c6656341fcd1..3579acf8d83d 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -21,21 +21,19 @@ import java.io._ import java.nio.ByteBuffer import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel} import java.nio.channels.FileChannel.MapMode -import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ListBuffer -import com.google.common.io.{ByteStreams, Closeables, Files} -import io.netty.channel.FileRegion +import com.google.common.io.Closeables +import io.netty.channel.{DefaultFileRegion, FileRegion} import io.netty.util.AbstractReferenceCounted import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.JavaUtils import org.apache.spark.security.CryptoStreamUtils -import org.apache.spark.util.{ByteBufferInputStream, Utils} +import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBuffer /** @@ -47,6 +45,8 @@ private[spark] class DiskStore( securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") + private val maxMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", + Int.MaxValue.toString) private val blockSizes = new ConcurrentHashMap[String, Long]() def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name) @@ -108,25 +108,7 @@ private[spark] class DiskStore( new EncryptedBlockData(file, blockSize, conf, key) case _ => - val channel = new FileInputStream(file).getChannel() - if (blockSize < minMemoryMapBytes) { - // For small files, directly read rather than memory map. 
- Utils.tryWithSafeFinally { - val buf = ByteBuffer.allocate(blockSize.toInt) - JavaUtils.readFully(channel, buf) - buf.flip() - new ByteBufferBlockData(new ChunkedByteBuffer(buf), true) - } { - channel.close() - } - } else { - Utils.tryWithSafeFinally { - new ByteBufferBlockData( - new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)), true) - } { - channel.close() - } - } + new DiskBlockData(minMemoryMapBytes, maxMemoryMapBytes, file, blockSize) } } @@ -165,6 +147,61 @@ private[spark] class DiskStore( } +private class DiskBlockData( + minMemoryMapBytes: Long, + maxMemoryMapBytes: Long, + file: File, + blockSize: Long) extends BlockData { + + override def toInputStream(): InputStream = new FileInputStream(file) + + /** + * Returns a Netty-friendly wrapper for the block's data. + * + * Please see `ManagedBuffer.convertToNetty()` for more details. + */ + override def toNetty(): AnyRef = new DefaultFileRegion(file, 0, size) + + override def toChunkedByteBuffer(allocator: (Int) => ByteBuffer): ChunkedByteBuffer = { + Utils.tryWithResource(open()) { channel => + var remaining = blockSize + val chunks = new ListBuffer[ByteBuffer]() + while (remaining > 0) { + val chunkSize = math.min(remaining, maxMemoryMapBytes) + val chunk = allocator(chunkSize.toInt) + remaining -= chunkSize + JavaUtils.readFully(channel, chunk) + chunk.flip() + chunks += chunk + } + new ChunkedByteBuffer(chunks.toArray) + } + } + + override def toByteBuffer(): ByteBuffer = { + require(blockSize < maxMemoryMapBytes, + s"can't create a byte buffer of size $blockSize" + + s" since it exceeds ${Utils.bytesToString(maxMemoryMapBytes)}.") + Utils.tryWithResource(open()) { channel => + if (blockSize < minMemoryMapBytes) { + // For small files, directly read rather than memory map. 
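// The toChunkedByteBuffer loop above slices a block into pieces no larger than maxMemoryMapBytes,
// so even a very large block never needs one contiguous buffer; the small-file branch just below
// this aside reads directly instead of memory-mapping. A standalone sketch of only that chunking
// arithmetic (the totals and the chunk cap are made-up numbers, not Spark defaults):
object ChunkSizesSketch {
  // Split `total` bytes into chunk sizes of at most `maxChunk` bytes each.
  def chunkSizes(total: Long, maxChunk: Long): Seq[Long] = {
    require(maxChunk > 0, "chunk cap must be positive")
    var remaining = total
    val chunks = Seq.newBuilder[Long]
    while (remaining > 0) {
      val chunk = math.min(remaining, maxChunk)
      chunks += chunk
      remaining -= chunk
    }
    chunks.result()
  }

  def main(args: Array[String]): Unit = {
    println(chunkSizes(total = 10L, maxChunk = 4L))    // List(4, 4, 2)
    println(chunkSizes(total = 3L, maxChunk = 4L))     // List(3): small blocks fit in one buffer
  }
}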
+ val buf = ByteBuffer.allocate(blockSize.toInt) + JavaUtils.readFully(channel, buf) + buf.flip() + buf + } else { + channel.map(MapMode.READ_ONLY, 0, file.length) + } + } + } + + override def size: Long = blockSize + + override def dispose(): Unit = {} + + private def open() = new FileInputStream(file).getChannel +} + private class EncryptedBlockData( file: File, blockSize: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f8906117638b..2d176b62f8b3 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,18 +17,18 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.{File, InputStream, IOException} import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -52,6 +52,9 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point + * for a given remote host:port. + * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. */ private[spark] @@ -63,8 +66,10 @@ final class ShuffleBlockFetcherIterator( streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) - extends Iterator[(BlockId, InputStream)] with Logging { + extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -108,12 +113,21 @@ final class ShuffleBlockFetcherIterator( */ private[this] val fetchRequests = new Queue[FetchRequest] + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. 
+ */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L /** Current number of requests in flight */ private[this] var reqsInFlight = 0 + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + /** * The blocks that can't be decompressed successfully, it is used to guarantee that we retry * at most once for those corrupted blocks. @@ -129,6 +143,13 @@ final class ShuffleBlockFetcherIterator( @GuardedBy("this") private[this] var isZombie = false + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[File]() + initialize() // Decrements the buffer reference count. @@ -141,6 +162,19 @@ final class ShuffleBlockFetcherIterator( currentResult = null } + override def createTempShuffleFile(): File = { + blockManager.diskBlockManager.createTempLocalBlock()._2 + } + + override def registerTempShuffleFileToClean(file: File): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + /** * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ @@ -157,12 +191,20 @@ final class ShuffleBlockFetcherIterator( case SuccessFetchResult(_, address, _, buf, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } shuffleMetrics.incRemoteBlocksFetched(1) } buf.release() case _ => } } + shuffleFilesSet.foreach { file => + if (!file.delete()) { + logWarning("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath()) + } + } } private[this] def sendRequest(req: FetchRequest) { @@ -175,33 +217,42 @@ final class ShuffleBlockFetcherIterator( val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap val remainingBlocks = new HashSet[String]() ++= sizeMap.keys val blockIds = req.blocks.map(_._1.toString) - val address = req.address - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - new BlockFetchingListener { - override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { - // Only add the buffer to results queue if the iterator is not zombie, - // i.e. cleanup() has not been called yet. - ShuffleBlockFetcherIterator.this.synchronized { - if (!isZombie) { - // Increment the ref count because we need to pass this to a different thread. - // This needs to be released after use. - buf.retain() - remainingBlocks -= blockId - results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, - remainingBlocks.isEmpty)) - logDebug("remainingBlocks: " + remainingBlocks) - } + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + ShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. 
+ buf.retain() + remainingBlocks -= blockId + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) } - logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } + logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + } - override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { - logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), address, e)) - } + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + results.put(new FailureFetchResult(BlockId(blockId), address, e)) } - ) + } + + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. + if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, this) + } else { + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, null) + } } private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { @@ -209,7 +260,8 @@ final class ShuffleBlockFetcherIterator( // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize + + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress) // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. 
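// The comment above describes how remote blocks are split into FetchRequests; with this patch a
// request is also cut once it holds maxBlocksInFlightPerAddress blocks, not only once it reaches
// the byte target. A standalone sketch of that grouping (block ids and sizes are fabricated):
object GroupBlocksSketch {
  def group(
      blocks: Seq[(String, Long)],
      targetRequestSize: Long,
      maxBlocksPerRequest: Int): Seq[Seq[(String, Long)]] = {
    val requests = Seq.newBuilder[Seq[(String, Long)]]
    var current = Vector.empty[(String, Long)]
    var currentSize = 0L
    for ((id, size) <- blocks) {
      current = current :+ (id -> size)
      currentSize += size
      if (currentSize >= targetRequestSize || current.size >= maxBlocksPerRequest) {
        requests += current        // cut a request: big enough, or too many blocks for one address
        current = Vector.empty
        currentSize = 0L
      }
    }
    if (current.nonEmpty) requests += current
    requests.result()
  }

  def main(args: Array[String]): Unit = {
    val blocks = (1 to 7).map(i => s"shuffle_0_${i}_0" -> 10L)
    // 25-byte target and at most 2 blocks per request: groups of 2, 2, 2 and a final 1.
    group(blocks, targetRequestSize = 25L, maxBlocksPerRequest = 2).foreach(println)
  }
}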
@@ -238,11 +290,13 @@ final class ShuffleBlockFetcherIterator( } else if (size < 0) { throw new BlockException(blockId, "Negative block size " + size) } - if (curRequestSize >= targetRequestSize) { + if (curRequestSize >= targetRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress) { // Add this FetchRequest remoteRequests += new FetchRequest(address, curBlocks) + logDebug(s"Creating fetch request of $curRequestSize at $address " + + s"with ${curBlocks.size} blocks") curBlocks = new ArrayBuffer[(BlockId, Long)] - logDebug(s"Creating fetch request of $curRequestSize at $address") curRequestSize = 0 } } @@ -336,7 +390,11 @@ final class ShuffleBlockFetcherIterator( result match { case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } shuffleMetrics.incRemoteBlocksFetched(1) } bytesInFlight -= size @@ -401,12 +459,57 @@ final class ShuffleBlockFetcherIterator( } private def fetchUpToMaxBytes(): Unit = { - // Send fetch requests up to maxBytesInFlight - while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || - (reqsInFlight + 1 <= maxReqsInFlight && - bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) { - sendRequest(fetchRequests.dequeue()) + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while (isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) { + val request = defReqQueue.dequeue() + logDebug(s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. + while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + sendRequest(request) + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. 
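// fetchUpToMaxBytes above now applies two gates per request: the existing global bytes-in-flight
// limit and the new per-address block cap; requests that fail the per-address check are parked on
// deferredFetchRequests and retried later. A toy standalone sketch of that gating, with
// addressMaxedOut mirroring the isRemoteAddressMaxedOut helper defined next (all names, limits
// and addresses are illustrative, and the real check also caps concurrent requests):
import scala.collection.mutable

object DeferredFetchSketch {
  final case class Request(address: String, blocks: Int, bytes: Long)

  val maxBytesInFlight = 100L
  val maxBlocksInFlightPerAddress = 3
  var bytesInFlight = 0L
  val blocksInFlightPerAddress = mutable.Map.empty[String, Int].withDefaultValue(0)
  val deferred = mutable.Map.empty[String, mutable.Queue[Request]]

  def addressMaxedOut(r: Request): Boolean =
    blocksInFlightPerAddress(r.address) + r.blocks > maxBlocksInFlightPerAddress

  def globallyFetchable(r: Request): Boolean =
    bytesInFlight == 0 || bytesInFlight + r.bytes <= maxBytesInFlight

  def trySend(r: Request): Unit = {
    if (globallyFetchable(r) && !addressMaxedOut(r)) {
      bytesInFlight += r.bytes
      blocksInFlightPerAddress(r.address) += r.blocks
      println(s"sent $r")
    } else {
      deferred.getOrElseUpdate(r.address, mutable.Queue.empty).enqueue(r)
      println(s"deferred $r")
    }
  }

  def main(args: Array[String]): Unit = {
    trySend(Request("host-a:7337", blocks = 2, bytes = 40L))  // sent
    trySend(Request("host-a:7337", blocks = 2, bytes = 40L))  // deferred: 4 blocks exceeds the cap of 3
    trySend(Request("host-b:7337", blocks = 1, bytes = 40L))  // sent: different address, bytes still fit
  }
}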
+ def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress } } diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 90e3af2d0ec7..651e9c7b2ab6 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -29,9 +29,10 @@ import com.google.common.io.ByteStreams import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{UNROLL_MEMORY_CHECK_PERIOD, UNROLL_MEMORY_GROWTH_FACTOR} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} -import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel, StreamBlockId} +import org.apache.spark.storage._ import org.apache.spark.unsafe.Platform import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -190,11 +191,11 @@ private[spark] class MemoryStore( // Initial per-task memory to request for unrolling blocks (bytes). val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory - val memoryCheckPeriod = 16 + val memoryCheckPeriod = conf.get(UNROLL_MEMORY_CHECK_PERIOD) // Memory currently reserved by this task for this particular unrolling operation var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size - val memoryGrowthFactor = 1.5 + val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Keep track of unroll memory used by this particular block / putIterator() operation var unrollMemoryUsedByThisBlock = 0L // Underlying vector for unrolling the block @@ -325,6 +326,12 @@ private[spark] class MemoryStore( // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true + // Number of elements unrolled so far + var elementsUnrolled = 0L + // How often to check whether we need to request more memory + val memoryCheckPeriod = conf.get(UNROLL_MEMORY_CHECK_PERIOD) + // Memory to request as a multiple of current bbos size + val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Initial per-task memory to request for unrolling blocks (bytes). 
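// The two knobs read from the conf above (UNROLL_MEMORY_CHECK_PERIOD and
// UNROLL_MEMORY_GROWTH_FACTOR) control how often unrolling pauses to check its reservation and
// how much extra it asks for when it does. A standalone sketch of that reservation pattern; the
// element size, counts and thresholds are invented for illustration and are not Spark defaults:
object UnrollReservationSketch {
  def main(args: Array[String]): Unit = {
    val memoryCheckPeriod = 16L       // check the reservation every N elements
    val memoryGrowthFactor = 1.5      // ask for a multiple of the current size, not the exact gap
    var reserved = 4L * 1024          // current unroll reservation for this block, in bytes
    var currentSize = 0L
    var elementsUnrolled = 0L

    for (_ <- 1 to 200) {
      currentSize += 512              // pretend each unrolled element adds about 512 bytes
      elementsUnrolled += 1
      if (elementsUnrolled % memoryCheckPeriod == 0 && currentSize >= reserved) {
        val amountToRequest = (currentSize * memoryGrowthFactor - reserved).toLong
        reserved += amountToRequest   // in the real code this request may fail, stopping the unroll
        println(s"after $elementsUnrolled elements: size=$currentSize, reserved=$reserved")
      }
    }
  }
}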
val initialMemoryThreshold = unrollMemoryThreshold // Keep track of unroll memory used by this particular block / putIterator() operation @@ -359,7 +366,7 @@ private[spark] class MemoryStore( def reserveAdditionalMemoryIfNecessary(): Unit = { if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock + val amountToRequest = (bbos.size * memoryGrowthFactor - unrollMemoryUsedByThisBlock).toLong keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) if (keepUnrolling) { unrollMemoryUsedByThisBlock += amountToRequest @@ -370,7 +377,10 @@ private[spark] class MemoryStore( // Unroll this block safely, checking whether we have exceeded our threshold while (values.hasNext && keepUnrolling) { serializationStream.writeObject(values.next())(classTag) - reserveAdditionalMemoryIfNecessary() + elementsUnrolled += 1 + if (elementsUnrolled % memoryCheckPeriod == 0) { + reserveAdditionalMemoryIfNecessary() + } } // Make sure that we have enough memory to store the block. By this point, it is possible that @@ -534,20 +544,38 @@ private[spark] class MemoryStore( } if (freedMemory >= space) { - logInfo(s"${selectedBlocks.size} blocks selected for dropping " + - s"(${Utils.bytesToString(freedMemory)} bytes)") - for (blockId <- selectedBlocks) { - val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one task should be dropping - // blocks and removing entries. However the check is still here for - // future safety. - if (entry != null) { - dropBlock(blockId, entry) + var lastSuccessfulBlock = -1 + try { + logInfo(s"${selectedBlocks.size} blocks selected for dropping " + + s"(${Utils.bytesToString(freedMemory)} bytes)") + (0 until selectedBlocks.size).foreach { idx => + val blockId = selectedBlocks(idx) + val entry = entries.synchronized { + entries.get(blockId) + } + // This should never be null as only one task should be dropping + // blocks and removing entries. However the check is still here for + // future safety. 
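// The eviction code that follows drops the selected blocks inside try/finally: it records the
// index of the last block dropped successfully and, if it did not reach the end, unlocks the
// remaining blocks in the finally clause so other tasks are not left blocked on them. A
// stripped-down standalone sketch of that pattern (a Set stands in for the real lock bookkeeping):
import scala.collection.mutable

object UnlockRemainderSketch {
  def dropAll(selected: Seq[String], locked: mutable.Set[String], failAt: Option[Int]): Unit = {
    var lastSuccessful = -1
    try {
      selected.indices.foreach { idx =>
        if (failAt.contains(idx)) throw new RuntimeException(s"simulated failure at $idx")
        locked -= selected(idx)       // "dropping" the block also releases its lock
        lastSuccessful = idx
      }
    } finally {
      // Blocks we never reached are still locked; release them before propagating the failure.
      if (lastSuccessful != selected.size - 1) {
        (lastSuccessful + 1 until selected.size).foreach(idx => locked -= selected(idx))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val locked = mutable.Set("a", "b", "c", "d")
    try dropAll(Seq("a", "b", "c", "d"), locked, failAt = Some(2)) catch { case _: RuntimeException => () }
    println(locked)                   // Set(): "c" and "d" were unlocked by the finally clause
  }
}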
+ if (entry != null) { + dropBlock(blockId, entry) + afterDropAction(blockId) + } + lastSuccessfulBlock = idx + } + logInfo(s"After dropping ${selectedBlocks.size} blocks, " + + s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}") + freedMemory + } finally { + // like BlockManager.doPut, we use a finally rather than a catch to avoid having to deal + // with InterruptedException + if (lastSuccessfulBlock != selectedBlocks.size - 1) { + // the blocks we didn't process successfully are still locked, so we have to unlock them + (lastSuccessfulBlock + 1 until selectedBlocks.size).foreach { idx => + val blockId = selectedBlocks(idx) + blockInfoManager.unlock(blockId) + } } } - logInfo(s"After dropping ${selectedBlocks.size} blocks, " + - s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}") - freedMemory } else { blockId.foreach { id => logInfo(s"Will not store $id") @@ -560,6 +588,9 @@ private[spark] class MemoryStore( } } + // hook for testing, so we can simulate a race + protected def afterDropAction(blockId: BlockId): Unit = {} + def contains(blockId: BlockId): Boolean = { entries.synchronized { entries.containsKey(blockId) } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index edf328b5ae53..5ee04dad6ed4 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -21,11 +21,12 @@ import java.net.{URI, URL} import javax.servlet.DispatcherType import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} -import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.xml.Node +import org.eclipse.jetty.client.HttpClient import org.eclipse.jetty.client.api.Response +import org.eclipse.jetty.client.http.HttpClientTransportOverHTTP import org.eclipse.jetty.proxy.ProxyServlet import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ @@ -52,7 +53,7 @@ private[spark] object JettyUtils extends Logging { // implicit conversion from many types of functions to jetty Handlers. 
type Responder[T] = HttpServletRequest => T - class ServletParams[T <% AnyRef](val responder: Responder[T], + class ServletParams[T <: AnyRef](val responder: Responder[T], val contentType: String, val extractFn: T => String = (in: Any) => in.toString) {} @@ -66,7 +67,7 @@ private[spark] object JettyUtils extends Logging { implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] = new ServletParams(responder, "text/plain") - def createServlet[T <% AnyRef]( + def createServlet[T <: AnyRef]( servletParams: ServletParams[T], securityMgr: SecurityManager, conf: SparkConf): HttpServlet = { @@ -111,7 +112,7 @@ private[spark] object JettyUtils extends Logging { } /** Create a context handler that responds to a request with the given path prefix */ - def createServletHandler[T <% AnyRef]( + def createServletHandler[T <: AnyRef]( path: String, servletParams: ServletParams[T], securityMgr: SecurityManager, @@ -192,20 +193,35 @@ private[spark] object JettyUtils extends Logging { } /** Create a handler for proxying request to Workers and Application Drivers */ - def createProxyHandler( - prefix: String, - target: String): ServletContextHandler = { + def createProxyHandler(idToUiAddress: String => Option[String]): ServletContextHandler = { val servlet = new ProxyServlet { override def rewriteTarget(request: HttpServletRequest): String = { - val rewrittenURI = createProxyURI( - prefix, target, request.getRequestURI(), request.getQueryString()) - if (rewrittenURI == null) { - return null - } - if (!validateDestination(rewrittenURI.getHost(), rewrittenURI.getPort())) { - return null + val path = request.getPathInfo + if (path == null) return null + + val prefixTrailingSlashIndex = path.indexOf('/', 1) + val prefix = if (prefixTrailingSlashIndex == -1) { + path + } else { + path.substring(0, prefixTrailingSlashIndex) } - rewrittenURI.toString() + val id = prefix.drop(1) + + // Query master state for id's corresponding UI address + // If that address exists, try to turn it into a valid, target URI string + // Otherwise, return null + idToUiAddress(id) + .map(createProxyURI(prefix, _, path, request.getQueryString)) + .filter(uri => uri != null && validateDestination(uri.getHost, uri.getPort)) + .map(_.toString) + .orNull + } + + override def newHttpClient(): HttpClient = { + // SPARK-21176: Use the Jetty logic to calculate the number of selector threads (#CPUs/2), + // but limit it to 8 max. 
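// A quick worked example of the selector-thread formula described in the SPARK-21176 comment
// above; the core counts are just sample inputs:
object SelectorCountSketch {
  def numSelectors(cores: Int): Int = math.max(1, math.min(8, cores / 2))

  def main(args: Array[String]): Unit = {
    Seq(1, 4, 8, 32).foreach(c => println(s"$c cores -> ${numSelectors(c)} selectors"))
    // 1 -> 1 (floor), 4 -> 2, 8 -> 4, 32 -> 8 (capped)
  }
}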
+ val numSelectors = math.max(1, math.min(8, Runtime.getRuntime().availableProcessors() / 2)) + new HttpClient(new HttpClientTransportOverHTTP(numSelectors), null) } override def filterServerResponseHeader( @@ -214,8 +230,8 @@ private[spark] object JettyUtils extends Logging { headerName: String, headerValue: String): String = { if (headerName.equalsIgnoreCase("location")) { - val newHeader = createProxyLocationHeader( - prefix, headerValue, clientRequest, serverResponse.getRequest().getURI()) + val newHeader = createProxyLocationHeader(headerValue, clientRequest, + serverResponse.getRequest().getURI()) if (newHeader != null) { return newHeader } @@ -227,8 +243,8 @@ private[spark] object JettyUtils extends Logging { val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) - contextHandler.setContextPath(prefix) - contextHandler.addServlet(holder, "/") + contextHandler.setContextPath("/proxy") + contextHandler.addServlet(holder, "/*") contextHandler } @@ -426,7 +442,7 @@ private[spark] object JettyUtils extends Logging { val rest = path.substring(prefix.length()) if (!rest.isEmpty()) { - if (!rest.startsWith("/")) { + if (!rest.startsWith("/") && !uri.endsWith("/")) { uri.append("/") } uri.append(rest) @@ -446,14 +462,15 @@ private[spark] object JettyUtils extends Logging { } def createProxyLocationHeader( - prefix: String, headerValue: String, clientRequest: HttpServletRequest, targetUri: URI): String = { val toReplace = targetUri.getScheme() + "://" + targetUri.getAuthority() if (headerValue.startsWith(toReplace)) { - clientRequest.getScheme() + "://" + clientRequest.getHeader("host") + - prefix + headerValue.substring(toReplace.length()) + val id = clientRequest.getPathInfo.substring("/proxy/".length).takeWhile(_ != '/') + val headerPath = headerValue.substring(toReplace.length) + + s"${clientRequest.getScheme}://${clientRequest.getHeader("host")}/proxy/$id$headerPath" } else { null } diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 79974df2603f..65fa38387b9e 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -94,14 +94,16 @@ private[ui] trait PagedTable[T] { val _dataSource = dataSource try { val PageData(totalPages, data) = _dataSource.pageData(page) + val pageNavi = pageNavigation(page, _dataSource.pageSize, totalPages)
       <div>
-        {pageNavigation(page, _dataSource.pageSize, totalPages)}
+        {pageNavi}
         <table class={tableCssClass} id={tableId}>
           {headers}
           <tbody>
             {data.map(row)}
           </tbody>
         </table>
+        {pageNavi}
       </div>
} catch { case e: IndexOutOfBoundsException => diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index bf4cf79e9faa..6e94073238a5 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -50,6 +50,7 @@ private[spark] class SparkUI private ( val operationGraphListener: RDDOperationGraphListener, var appName: String, val basePath: String, + val lastUpdateTime: Option[Long] = None, val startTime: Long) extends WebUI(securityManager, securityManager.getSSLOptions("ui"), SparkUI.getUIPort(conf), conf, basePath, "SparkUI") @@ -60,6 +61,8 @@ private[spark] class SparkUI private ( var appId: String = _ + var appSparkVersion = org.apache.spark.SPARK_VERSION + private var streamingJobProgressListener: Option[SparkListener] = None /** Initialize all components of the server. */ @@ -84,7 +87,9 @@ private[spark] class SparkUI private ( initialize() def getSparkUser: String = { - environmentListener.systemProperties.toMap.getOrElse("user.name", "") + environmentListener.sparkUser + .orElse(environmentListener.systemProperties.toMap.get("user.name")) + .getOrElse("") } def getAppName: String = appName @@ -118,7 +123,8 @@ private[spark] class SparkUI private ( duration = 0, lastUpdated = new Date(startTime), sparkUser = getSparkUser, - completed = false + completed = false, + appSparkVersion = appSparkVersion )) )) } @@ -139,6 +145,7 @@ private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) def appName: String = parent.appName + def appSparkVersion: String = parent.appSparkVersion } private[spark] object SparkUI { @@ -155,13 +162,14 @@ private[spark] object SparkUI { def createLiveUI( sc: SparkContext, conf: SparkConf, - listenerBus: SparkListenerBus, jobProgressListener: JobProgressListener, securityManager: SecurityManager, appName: String, startTime: Long): SparkUI = { - create(Some(sc), conf, listenerBus, securityManager, appName, - jobProgressListener = Some(jobProgressListener), startTime = startTime) + create(Some(sc), conf, + sc.listenerBus.addToStatusQueue, + securityManager, appName, jobProgressListener = Some(jobProgressListener), + startTime = startTime) } def createHistoryUI( @@ -170,9 +178,10 @@ private[spark] object SparkUI { securityManager: SecurityManager, appName: String, basePath: String, + lastUpdateTime: Option[Long], startTime: Long): SparkUI = { - val sparkUI = create( - None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + val sparkUI = create(None, conf, listenerBus.addListener, securityManager, appName, basePath, + lastUpdateTime = lastUpdateTime, startTime = startTime) val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], Utils.getContextOrSparkClassLoader).asScala @@ -193,16 +202,17 @@ private[spark] object SparkUI { private def create( sc: Option[SparkContext], conf: SparkConf, - listenerBus: SparkListenerBus, + addListenerFn: SparkListenerInterface => Unit, securityManager: SecurityManager, appName: String, basePath: String = "", jobProgressListener: Option[JobProgressListener] = None, + lastUpdateTime: Option[Long] = None, startTime: Long): SparkUI = { val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse { val listener = new JobProgressListener(conf) - listenerBus.addListener(listener) + addListenerFn(listener) listener } @@ -212,14 +222,14 @@ private[spark] object SparkUI { val storageListener = new 
StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) - listenerBus.addListener(environmentListener) - listenerBus.addListener(storageStatusListener) - listenerBus.addListener(executorsListener) - listenerBus.addListener(storageListener) - listenerBus.addListener(operationGraphListener) + addListenerFn(environmentListener) + addListenerFn(storageStatusListener) + addListenerFn(executorsListener) + addListenerFn(storageListener) + addListenerFn(operationGraphListener) new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener, executorsListener, _jobProgressListener, storageListener, operationGraphListener, - appName, basePath, startTime) + appName, basePath, lastUpdateTime, startTime) } } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 79b0d81af52b..ba798df13c95 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import scala.xml._ import scala.xml.transform.{RewriteRule, RuleTransformer} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.internal.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -34,6 +36,8 @@ private[spark] object UIUtils extends Logging { val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable" + private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r + // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue(): SimpleDateFormat = @@ -228,7 +232,7 @@ private[spark] object UIUtils extends Logging {
{completed}/{total} + { if (failed == 0 && skipped == 0 && started > 0) s"($started running)" } { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } { reasonToNumKilled.toSeq.sortBy(-_._2).map { @@ -527,4 +532,21 @@ private[spark] object UIUtils extends Logging { origHref } } + + /** + * Remove suspicious characters of user input to prevent Cross-Site scripting (XSS) attacks + * + * For more information about XSS testing: + * https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and + * https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001) + */ + def stripXSS(requestParameter: String): String = { + if (requestParameter == null) { + null + } else { + // Remove new lines and single quotes, followed by escaping HTML version 4.0 + StringEscapeUtils.escapeHtml4( + NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, "")) + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index 094953f2f5b5..6229e800957d 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -66,11 +66,11 @@ private[spark] object UIWorkloadGenerator { def nextFloat(): Float = new Random().nextFloat() val jobs = Seq[(String, () => Long)]( - ("Count", baseData.count), - ("Cache and Count", baseData.map(x => x).cache().count), - ("Single Shuffle", baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count), - ("Entirely failed phase", baseData.map(x => throw new Exception).count), - ("Partially failed phase", { + ("Count", () => baseData.count), + ("Cache and Count", () => baseData.map(x => x).cache().count), + ("Single Shuffle", () => baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count), + ("Entirely failed phase", () => baseData.map { x => throw new Exception(); 1 }.count), + ("Partially failed phase", () => { baseData.map{x => val probFailure = (4.0 / NUM_PARTITIONS) if (nextFloat() < probFailure) { @@ -79,7 +79,7 @@ private[spark] object UIWorkloadGenerator { 1 }.count }), - ("Partially failed phase (longer tasks)", { + ("Partially failed phase (longer tasks)", () => { baseData.map{x => val probFailure = (4.0 / NUM_PARTITIONS) if (nextFloat() < probFailure) { @@ -89,7 +89,7 @@ private[spark] object UIWorkloadGenerator { 1 }.count }), - ("Job with delays", baseData.map(x => Thread.sleep(100)).count) + ("Job with delays", () => baseData.map(x => Thread.sleep(100)).count) ) val barrier = new Semaphore(-nJobSet * jobs.size + 1) diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 8c18464e6477..61b12aaa32bb 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -34,11 +34,16 @@ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "en @DeveloperApi @deprecated("This class will be removed in a future release.", "2.2.0") class EnvironmentListener extends SparkListener { + var sparkUser: Option[String] = None var jvmInformation = Seq[(String, String)]() var sparkProperties = Seq[(String, String)]() var systemProperties = Seq[(String, String)]() var classpathEntries = Seq[(String, String)]() + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { + sparkUser = Some(event.sparkUser) + } + override def 
onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { val environmentDetails = environmentUpdate.environmentDetails diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 6ce3f511e89c..7b211ea5199c 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -28,8 +28,10 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage private val sc = parent.sc + // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: HttpServletRequest): Seq[Node] = { - val executorId = Option(request.getParameter("executorId")).map { executorId => + val executorId = + Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId => UIUtils.decodeURLParameter(executorId) }.getOrElse { throw new IllegalArgumentException(s"Missing executorId parameter") diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b7cbed468517..d63381c78bc3 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -82,7 +82,7 @@ private[ui] class ExecutorsPage(
++ -
++ + ++ ++ ++ diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index aabf6e0c63c0..64a1a292a384 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -11,7 +11,7 @@ * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and + * See the License for the specific language governing permissions and * limitations under the License. */ @@ -19,7 +19,7 @@ package org.apache.spark.ui.exec import scala.collection.mutable.{LinkedHashMap, ListBuffer} -import org.apache.spark.{ExceptionFailure, Resubmitted, SparkConf, SparkContext} +import org.apache.spark.{Resubmitted, SparkConf, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.{StorageStatus, StorageStatusListener} @@ -131,17 +131,17 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar if (info != null) { val eid = info.executorId val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - taskEnd.reason match { - case Resubmitted => - // Note: For resubmitted tasks, we continue to use the metrics that belong to the - // first attempt of this task. This may not be 100% accurate because the first attempt - // could have failed half-way through. The correct fix would be to keep track of the - // metrics added by each attempt, but this is much more complicated. - return - case _: ExceptionFailure => - taskSummary.tasksFailed += 1 - case _ => - taskSummary.tasksComplete += 1 + // Note: For resubmitted tasks, we continue to use the metrics that belong to the + // first attempt of this task. This may not be 100% accurate because the first attempt + // could have failed half-way through. The correct fix would be to keep track of the + // metrics added by each attempt, but this is much more complicated. 
+ if (taskEnd.reason == Resubmitted) { + return + } + if (info.successful) { + taskSummary.tasksComplete += 1 + } else { + taskSummary.tasksFailed += 1 } if (taskSummary.tasksActive >= 1) { taskSummary.tasksActive -= 1 diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 18be0870746e..a7f2caafe04b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -220,18 +220,20 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { jobTag: String, jobs: Seq[JobUIData], killEnabled: Boolean): Seq[Node] = { - val allParameters = request.getParameterMap.asScala.toMap + // stripXSS is called to remove suspicious characters used in XSS attacks + val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag)) .map(para => para._1 + "=" + para._2(0)) val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id" - val parameterJobPage = request.getParameter(jobTag + ".page") - val parameterJobSortColumn = request.getParameter(jobTag + ".sort") - val parameterJobSortDesc = request.getParameter(jobTag + ".desc") - val parameterJobPageSize = request.getParameter(jobTag + ".pageSize") - val parameterJobPrevPageSize = request.getParameter(jobTag + ".prevPageSize") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page")) + val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort")) + val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc")) + val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize")) + val parameterJobPrevPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".prevPageSize")) val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1) val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn => @@ -239,7 +241,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { }.getOrElse(jobIdTitle) val jobSortDesc = Option(parameterJobSortDesc).map(_.toBoolean).getOrElse( // New jobs should be shown above old jobs by default. 
- if (jobSortColumn == jobIdTitle) true else false + jobSortColumn == jobIdTitle ) val jobPageSize = Option(parameterJobPageSize).map(_.toInt).getOrElse(100) val jobPrevPageSize = Option(parameterJobPrevPageSize).map(_.toInt).getOrElse(jobPageSize) @@ -629,7 +631,8 @@ private[ui] class JobPagedTable( {if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"} - {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, + {UIUtils.makeProgressBar(started = job.numActiveTasks, + completed = job.completedIndices.size, failed = job.numFailedTasks, skipped = job.numSkippedTasks, reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 2b0816e35747..a30c13592947 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -115,7 +115,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { if (sc.isDefined && isFairScheduler) {

<h4>{pools.size} Fair Scheduler Pools</h4>

++ poolTable.toNodeSeq } else { - Seq[Node]() + Seq.empty[Node] } } if (shouldShowActiveStages) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 3131c4a1eb7d..9fb011a049b7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -187,7 +187,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { val listener = parent.jobProgresslistener listener.synchronized { - val parameterId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val jobId = parameterId.toInt diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 8870187f2219..a18e86ec0a73 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -329,13 +329,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { - val metrics = TaskMetrics.empty val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { logWarning("Task start for unknown stage " + taskStart.stageId) new StageUIData }) stageData.numActiveTasks += 1 - stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo, Some(metrics))) + stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo)) } for ( activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); @@ -375,6 +374,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { case kill: TaskKilled => execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) + case commitDenied: TaskCommitDenied => + execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( + commitDenied.toErrorString, execSummary.reasonToNumKilled.getOrElse( + commitDenied.toErrorString, 0) + 1) case _ => execSummary.failedTasks += 1 } @@ -391,6 +394,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) Some(kill.toErrorString) + case commitDenied: TaskCommitDenied => + stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( + commitDenied.toErrorString, stageData.reasonToNumKilled.getOrElse( + commitDenied.toErrorString, 0) + 1) + Some(commitDenied.toErrorString) case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates stageData.numFailedTasks += 1 Some(e.toErrorString) @@ -405,7 +413,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { updateAggregateMetrics(stageData, info.executorId, m, oldMetrics) } - val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info, None)) + val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info)) taskData.updateTaskInfo(info) taskData.updateTaskMetrics(taskMetrics) taskData.errorMessage = 
errorMessage @@ -424,10 +432,15 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { jobData.numActiveTasks -= 1 taskEnd.reason match { case Success => + jobData.completedIndices.add((taskEnd.stageId, info.index)) jobData.numCompletedTasks += 1 case kill: TaskKilled => jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) + case commitDenied: TaskCommitDenied => + jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( + commitDenied.toErrorString, jobData.reasonToNumKilled.getOrElse( + commitDenied.toErrorString, 0) + 1) case _ => jobData.numFailedTasks += 1 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 620c54c2dc0a..cc173381879a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all jobs in the given SparkContext. */ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { @@ -40,7 +40,8 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { - val jobId = Option(request.getParameter("id")).map(_.toInt) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) jobId.foreach { id => if (jobProgresslistener.activeJobs.contains(id)) { sc.foreach(_.cancelJob(id)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 8ee70d27cc09..819fe57e14b2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -31,7 +31,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val poolName = Option(request.getParameter("poolname")).map { poolname => + // stripXSS is called first to remove suspicious characters used in XSS attacks + val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname => UIUtils.decodeURLParameter(poolname) }.getOrElse { throw new IllegalArgumentException(s"Missing poolname parameter") @@ -40,7 +41,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { val poolToActiveStages = listener.poolToActiveStages val activeStages = poolToActiveStages.get(poolName) match { case Some(s) => s.values.toSeq - case None => Seq[StageInfo]() + case None => Seq.empty[StageInfo] } val shouldShowActiveStages = activeStages.nonEmpty val activeStagesTable = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 19325a2dc916..4d80308eb0a6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -87,17 +87,18 @@ private[ui] class 
StagePage(parent: StagesTab) extends WebUIPage("stage") { def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { - val parameterId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val parameterAttempt = request.getParameter("attempt") + val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt")) require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") - val parameterTaskPage = request.getParameter("task.page") - val parameterTaskSortColumn = request.getParameter("task.sort") - val parameterTaskSortDesc = request.getParameter("task.desc") - val parameterTaskPageSize = request.getParameter("task.pageSize") - val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize") + val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page")) + val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort")) + val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) + val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) + val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize")) val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => @@ -298,6 +299,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { stageData.hasShuffleRead, stageData.hasShuffleWrite, stageData.hasBytesSpilled, + parent.lastUpdateTime, currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, @@ -564,7 +566,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators) {

<h4>Accumulators</h4> ++ accumulableTable } else Seq()
+      if (hasAccumulators) { <h4>Accumulators</h4>
++ accumulableTable } else Seq.empty val aggMetrics = UIUtils.formatDuration(d)).getOrElse("") + val duration = taskData.taskDuration(lastUpdateTime).getOrElse(1L) + val formatDuration = + taskData.taskDuration(lastUpdateTime).map(d => UIUtils.formatDuration(d)).getOrElse("") val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) @@ -1016,7 +1021,8 @@ private[ui] class TaskDataSource( info.speculative, info.status, info.taskLocality.toString, - s"${info.executorId} / ${info.host}", + info.executorId, + info.host, info.launchTime, duration, formatDuration, @@ -1046,7 +1052,8 @@ private[ui] class TaskDataSource( case "Attempt" => Ordering.by(_.attempt) case "Status" => Ordering.by(_.status) case "Locality Level" => Ordering.by(_.taskLocality) - case "Executor ID / Host" => Ordering.by(_.executorIdAndHost) + case "Executor ID" => Ordering.by(_.executorId) + case "Host" => Ordering.by(_.host) case "Launch Time" => Ordering.by(_.launchTime) case "Duration" => Ordering.by(_.duration) case "Scheduler Delay" => Ordering.by(_.schedulerDelay) @@ -1150,6 +1157,7 @@ private[ui] class TaskPagedTable( hasShuffleRead: Boolean, hasShuffleWrite: Boolean, hasBytesSpilled: Boolean, + lastUpdateTime: Option[Long], currentTime: Long, pageSize: Int, sortColumn: String, @@ -1175,6 +1183,7 @@ private[ui] class TaskPagedTable( hasShuffleRead, hasShuffleWrite, hasBytesSpilled, + lastUpdateTime, currentTime, pageSize, sortColumn, @@ -1199,7 +1208,7 @@ private[ui] class TaskPagedTable( val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), + ("Executor ID", ""), ("Host", ""), ("Launch Time", ""), ("Duration", ""), ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), ("GC Time", ""), @@ -1270,8 +1279,9 @@ private[ui] class TaskPagedTable( {if (task.speculative) s"${task.attempt} (speculative)" else task.attempt.toString} {task.status} {task.taskLocality} + {task.executorId} -
{task.executorIdAndHost}
+ {task.host}
{ task.logs.map { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 256b726fa7ee..f0a12a28de06 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -42,15 +42,17 @@ private[ui] class StageTableBase( isFairScheduler: Boolean, killEnabled: Boolean, isFailedStage: Boolean) { - val allParameters = request.getParameterMap().asScala.toMap + // stripXSS is called to remove suspicious characters used in XSS attacks + val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag)) .map(para => para._1 + "=" + para._2(0)) - val parameterStagePage = request.getParameter(stageTag + ".page") - val parameterStageSortColumn = request.getParameter(stageTag + ".sort") - val parameterStageSortDesc = request.getParameter(stageTag + ".desc") - val parameterStagePageSize = request.getParameter(stageTag + ".pageSize") - val parameterStagePrevPageSize = request.getParameter(stageTag + ".prevPageSize") + val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page")) + val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort")) + val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc")) + val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize")) + val parameterStagePrevPageSize = + UIUtils.stripXSS(request.getParameter(stageTag + ".prevPageSize")) val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1) val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn => @@ -58,7 +60,7 @@ private[ui] class StageTableBase( }.getOrElse("Stage Id") val stageSortDesc = Option(parameterStageSortDesc).map(_.toBoolean).getOrElse( // New stages should be shown above old jobs by default. - if (stageSortColumn == "Stage Id") true else false + stageSortColumn == "Stage Id" ) val stagePageSize = Option(parameterStagePageSize).map(_.toInt).getOrElse(100) val stagePrevPageSize = Option(parameterStagePrevPageSize).map(_.toInt) @@ -512,4 +514,3 @@ private[ui] class StageDataSource( } } } - diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 181465bdf960..0787ea662590 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all stages in the given SparkContext. 
*/ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") { @@ -30,6 +30,7 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" val progressListener = parent.jobProgressListener val operationGraphListener = parent.operationGraphListener val executorsListener = parent.executorsListener + val lastUpdateTime = parent.lastUpdateTime attachPage(new AllStagesPage(this)) attachPage(new StagePage(this)) @@ -39,7 +40,8 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { - val stageId = Option(request.getParameter("id")).map(_.toInt) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) stageId.foreach { id => if (progressListener.activeStages.contains(id)) { sc.foreach(_.cancelStage(id, "killed via the Web UI")) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index ac1a74ad8029..5acec0d0f54c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -20,6 +20,8 @@ package org.apache.spark.ui.jobs import scala.collection.mutable import scala.collection.mutable.{HashMap, LinkedHashMap} +import com.google.common.collect.Interners + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor._ import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} @@ -62,6 +64,7 @@ private[spark] object UIData { var numTasks: Int = 0, var numActiveTasks: Int = 0, var numCompletedTasks: Int = 0, + var completedIndices: OpenHashSet[(Int, Int)] = new OpenHashSet[(Int, Int)](), var numSkippedTasks: Int = 0, var numFailedTasks: Int = 0, var reasonToNumKilled: Map[String, Int] = Map.empty, @@ -94,6 +97,7 @@ private[spark] object UIData { var memoryBytesSpilled: Long = _ var diskBytesSpilled: Long = _ var isBlacklisted: Int = _ + var lastUpdateTime: Option[Long] = None var schedulingPool: String = "" var description: Option[String] = None @@ -106,15 +110,15 @@ private[spark] object UIData { def hasOutput: Boolean = outputBytes > 0 def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0 def hasShuffleWrite: Boolean = shuffleWriteBytes > 0 - def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 && diskBytesSpilled > 0 + def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 || diskBytesSpilled > 0 } /** * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation. 
*/ - class TaskUIData private( - private var _taskInfo: TaskInfo, - private var _metrics: Option[TaskMetricsUIData]) { + class TaskUIData private(private var _taskInfo: TaskInfo) { + + private[this] var _metrics: Option[TaskMetricsUIData] = Some(TaskMetricsUIData.EMPTY) var errorMessage: Option[String] = None @@ -127,12 +131,12 @@ private[spark] object UIData { } def updateTaskMetrics(metrics: Option[TaskMetrics]): Unit = { - _metrics = TaskUIData.toTaskMetricsUIData(metrics) + _metrics = metrics.map(TaskMetricsUIData.fromTaskMetrics) } - def taskDuration: Option[Long] = { + def taskDuration(lastUpdateTime: Option[Long] = None): Option[Long] = { if (taskInfo.status == "RUNNING") { - Some(_taskInfo.timeRunning(System.currentTimeMillis)) + Some(_taskInfo.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis))) } else { _metrics.map(_.executorRunTime) } @@ -140,28 +144,16 @@ private[spark] object UIData { } object TaskUIData { - def apply(taskInfo: TaskInfo, metrics: Option[TaskMetrics]): TaskUIData = { - new TaskUIData(dropInternalAndSQLAccumulables(taskInfo), toTaskMetricsUIData(metrics)) + + private val stringInterner = Interners.newWeakInterner[String]() + + /** String interning to reduce the memory usage. */ + private def weakIntern(s: String): String = { + stringInterner.intern(s) } - private def toTaskMetricsUIData(metrics: Option[TaskMetrics]): Option[TaskMetricsUIData] = { - metrics.map { m => - TaskMetricsUIData( - executorDeserializeTime = m.executorDeserializeTime, - executorDeserializeCpuTime = m.executorDeserializeCpuTime, - executorRunTime = m.executorRunTime, - executorCpuTime = m.executorCpuTime, - resultSize = m.resultSize, - jvmGCTime = m.jvmGCTime, - resultSerializationTime = m.resultSerializationTime, - memoryBytesSpilled = m.memoryBytesSpilled, - diskBytesSpilled = m.diskBytesSpilled, - peakExecutionMemory = m.peakExecutionMemory, - inputMetrics = InputMetricsUIData(m.inputMetrics), - outputMetrics = OutputMetricsUIData(m.outputMetrics), - shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics), - shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics)) - } + def apply(taskInfo: TaskInfo): TaskUIData = { + new TaskUIData(dropInternalAndSQLAccumulables(taskInfo)) } /** @@ -174,8 +166,8 @@ private[spark] object UIData { index = taskInfo.index, attemptNumber = taskInfo.attemptNumber, launchTime = taskInfo.launchTime, - executorId = taskInfo.executorId, - host = taskInfo.host, + executorId = weakIntern(taskInfo.executorId), + host = weakIntern(taskInfo.host), taskLocality = taskInfo.taskLocality, speculative = taskInfo.speculative ) @@ -206,6 +198,28 @@ private[spark] object UIData { shuffleReadMetrics: ShuffleReadMetricsUIData, shuffleWriteMetrics: ShuffleWriteMetricsUIData) + object TaskMetricsUIData { + def fromTaskMetrics(m: TaskMetrics): TaskMetricsUIData = { + TaskMetricsUIData( + executorDeserializeTime = m.executorDeserializeTime, + executorDeserializeCpuTime = m.executorDeserializeCpuTime, + executorRunTime = m.executorRunTime, + executorCpuTime = m.executorCpuTime, + resultSize = m.resultSize, + jvmGCTime = m.jvmGCTime, + resultSerializationTime = m.resultSerializationTime, + memoryBytesSpilled = m.memoryBytesSpilled, + diskBytesSpilled = m.diskBytesSpilled, + peakExecutionMemory = m.peakExecutionMemory, + inputMetrics = InputMetricsUIData(m.inputMetrics), + outputMetrics = OutputMetricsUIData(m.outputMetrics), + shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics), + shuffleWriteMetrics = 
ShuffleWriteMetricsUIData(m.shuffleWriteMetrics)) + } + + val EMPTY: TaskMetricsUIData = fromTaskMetrics(TaskMetrics.empty) + } + case class InputMetricsUIData(bytesRead: Long, recordsRead: Long) object InputMetricsUIData { def apply(metrics: InputMetrics): InputMetricsUIData = { @@ -238,6 +252,7 @@ private[spark] object UIData { remoteBlocksFetched: Long, localBlocksFetched: Long, remoteBytesRead: Long, + remoteBytesReadToDisk: Long, localBytesRead: Long, fetchWaitTime: Long, recordsRead: Long, @@ -261,6 +276,7 @@ private[spark] object UIData { remoteBlocksFetched = metrics.remoteBlocksFetched, localBlocksFetched = metrics.localBlocksFetched, remoteBytesRead = metrics.remoteBytesRead, + remoteBytesReadToDisk = metrics.remoteBytesReadToDisk, localBytesRead = metrics.localBytesRead, fetchWaitTime = metrics.fetchWaitTime, recordsRead = metrics.recordsRead, @@ -269,7 +285,7 @@ private[spark] object UIData { ) } } - private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0) + private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0, 0) } case class ShuffleWriteMetricsUIData( diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 43bfe0aacf35..bb763248cd7e 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -26,7 +26,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.internal.Logging import org.apache.spark.scheduler.StageInfo -import org.apache.spark.storage.{RDDInfo, StorageLevel} +import org.apache.spark.storage.StorageLevel /** * A representation of a generic cluster graph used for storing information on RDD operations. 
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index a1a0c729b924..e8ff08f7d88f 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -31,14 +31,15 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val parameterId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val parameterBlockPage = request.getParameter("block.page") - val parameterBlockSortColumn = request.getParameter("block.sort") - val parameterBlockSortDesc = request.getParameter("block.desc") - val parameterBlockPageSize = request.getParameter("block.pageSize") - val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize") + val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page")) + val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort")) + val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc")) + val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize")) + val parameterBlockPrevPageSize = UIUtils.stripXSS(request.getParameter("block.prevPageSize")) val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") @@ -50,7 +51,7 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true) .getOrElse { // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage("RDD Not Found", Seq[Node](), parent) + return UIUtils.headerSparkPage("RDD Not Found", Seq.empty[Node], parent) } // Worker table diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index a65ec75cc5db..f4a736d6d439 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -23,12 +23,9 @@ import java.util.{ArrayList, Collections} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong -import scala.collection.JavaConverters._ - import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} import org.apache.spark.scheduler.AccumulableInfo - private[spark] case class AccumulatorMetadata( id: Long, name: Option[String], @@ -68,7 +65,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { private def assertMetadataNotNull(): Unit = { if (metadata == null) { - throw new IllegalAccessError("The metadata of this accumulator has not been assigned yet.") + throw new IllegalStateException("The metadata of this accumulator has not been assigned yet.") } } @@ -84,10 +81,11 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { * Returns the name of this accumulator, can only be called after registration. 
*/ final def name: Option[String] = { + assertMetadataNotNull() + if (atDriverSide) { - AccumulatorContext.get(id).flatMap(_.metadata.name) + metadata.name.orElse(AccumulatorContext.get(id).flatMap(_.metadata.name)) } else { - assertMetadataNotNull() metadata.name } } @@ -165,13 +163,15 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } val copyAcc = copyAndReset() assert(copyAcc.isZero, "copyAndReset must return a zero value copy") - val isInternalAcc = - (name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)) || - getClass.getSimpleName == "SQLMetric" + val isInternalAcc = name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX) if (isInternalAcc) { // Do not serialize the name of internal accumulator and send it to executor. copyAcc.metadata = metadata.copy(name = None) } else { + // For non-internal accumulators, we still need to send the name because users may need to + // access the accumulator name at executor side, or they may keep the accumulators sent from + // executors and access the name when the registered accumulator is already garbage + // collected(e.g. SQLMetrics). copyAcc.metadata = metadata } copyAcc @@ -262,7 +262,7 @@ private[spark] object AccumulatorContext { // Since we are storing weak references, we must check whether the underlying data is valid. val acc = ref.get if (acc eq null) { - throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id") + throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id") } acc } diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala index 1b2b1932e0c3..eff0aa4453f0 100644 --- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -51,6 +51,10 @@ private[spark] class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Orderin this } + def poll(): A = { + underlying.poll() + } + override def +=(elem1: A, elem2: A, elems: A*): this.type = { this += elem1 += elem2 ++= elems } diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala index 50dc948e6c41..a938cb07724c 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -20,8 +20,6 @@ package org.apache.spark.util import java.io.InputStream import java.nio.ByteBuffer -import org.apache.spark.storage.StorageUtils - /** * Reads data from a ByteBuffer. 
*/ diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 489688cb0880..48a1d7b84b61 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -81,7 +81,7 @@ private[spark] object ClosureCleaner extends Logging { val stack = Stack[Class[_]](obj.getClass) while (!stack.isEmpty) { val cr = getClassReader(stack.pop()) - val set = Set[Class[_]]() + val set = Set.empty[Class[_]] cr.accept(new InnerClosureFinder(set), 0) for (cls <- set -- seen) { seen += cls @@ -180,16 +180,18 @@ private[spark] object ClosureCleaner extends Logging { val declaredFields = func.getClass.getDeclaredFields val declaredMethods = func.getClass.getDeclaredMethods - logDebug(" + declared fields: " + declaredFields.size) - declaredFields.foreach { f => logDebug(" " + f) } - logDebug(" + declared methods: " + declaredMethods.size) - declaredMethods.foreach { m => logDebug(" " + m) } - logDebug(" + inner classes: " + innerClasses.size) - innerClasses.foreach { c => logDebug(" " + c.getName) } - logDebug(" + outer classes: " + outerClasses.size) - outerClasses.foreach { c => logDebug(" " + c.getName) } - logDebug(" + outer objects: " + outerObjects.size) - outerObjects.foreach { o => logDebug(" " + o) } + if (log.isDebugEnabled) { + logDebug(" + declared fields: " + declaredFields.size) + declaredFields.foreach { f => logDebug(" " + f) } + logDebug(" + declared methods: " + declaredMethods.size) + declaredMethods.foreach { m => logDebug(" " + m) } + logDebug(" + inner classes: " + innerClasses.size) + innerClasses.foreach { c => logDebug(" " + c.getName) } + logDebug(" + outer classes: " + outerClasses.size) + outerClasses.foreach { c => logDebug(" " + c.getName) } + logDebug(" + outer objects: " + outerObjects.size) + outerObjects.foreach { o => logDebug(" " + o) } + } // Fail fast if we detect return statements in closures getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) @@ -201,7 +203,7 @@ private[spark] object ClosureCleaner extends Logging { // Initialize accessed fields with the outer classes first // This step is needed to associate the fields to the correct classes later for (cls <- outerClasses) { - accessedFields(cls) = Set[String]() + accessedFields(cls) = Set.empty[String] } // Populate accessed fields by visiting all fields and methods accessed by this and // all of its inner closures. 
If transitive cleaning is enabled, this may recursively diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 8296c4294242..8406826a228d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -339,6 +339,7 @@ private[spark] object JsonProtocol { ("Local Blocks Fetched" -> taskMetrics.shuffleReadMetrics.localBlocksFetched) ~ ("Fetch Wait Time" -> taskMetrics.shuffleReadMetrics.fetchWaitTime) ~ ("Remote Bytes Read" -> taskMetrics.shuffleReadMetrics.remoteBytesRead) ~ + ("Remote Bytes Read To Disk" -> taskMetrics.shuffleReadMetrics.remoteBytesReadToDisk) ~ ("Local Bytes Read" -> taskMetrics.shuffleReadMetrics.localBytesRead) ~ ("Total Records Read" -> taskMetrics.shuffleReadMetrics.recordsRead) val shuffleWriteMetrics: JValue = @@ -695,7 +696,7 @@ private[spark] object JsonProtocol { val accumulatedValues = { Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) - case None => Seq[AccumulableInfo]() + case None => Seq.empty[AccumulableInfo] } } @@ -725,7 +726,7 @@ private[spark] object JsonProtocol { val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean]) val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) - case None => Seq[AccumulableInfo]() + case None => Seq.empty[AccumulableInfo] } val taskInfo = @@ -804,6 +805,8 @@ private[spark] object JsonProtocol { readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) + Utils.jsonOption(readJson \ "Remote Bytes Read To Disk") + .foreach { v => readMetrics.incRemoteBytesReadToDisk(v.extract[Long])} readMetrics.incLocalBytesRead( Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index fa5ad4e8d81e..76a56298aaeb 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -23,6 +23,8 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal +import com.codahale.metrics.Timer + import org.apache.spark.internal.Logging /** @@ -30,14 +32,22 @@ import org.apache.spark.internal.Logging */ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { + private[this] val listenersPlusTimers = new CopyOnWriteArrayList[(L, Option[Timer])] + // Marked `private[spark]` for access in tests. - private[spark] val listeners = new CopyOnWriteArrayList[L] + private[spark] def listeners = listenersPlusTimers.asScala.map(_._1).asJava + + /** + * Returns a CodaHale metrics Timer for measuring the listener's event processing time. + * This method is intended to be overridden by subclasses. + */ + protected def getTimer(listener: L): Option[Timer] = None /** * Add a listener to listen events. This method is thread-safe and can be called in any thread. 
*/ final def addListener(listener: L): Unit = { - listeners.add(listener) + listenersPlusTimers.add((listener, getTimer(listener))) } /** @@ -45,7 +55,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * in any thread. */ final def removeListener(listener: L): Unit = { - listeners.remove(listener) + listenersPlusTimers.asScala.find(_._1 eq listener).foreach { listenerAndTimer => + listenersPlusTimers.remove(listenerAndTimer) + } } /** @@ -56,14 +68,25 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here we use // Java Iterator directly. - val iter = listeners.iterator + val iter = listenersPlusTimers.iterator while (iter.hasNext) { - val listener = iter.next() + val listenerAndMaybeTimer = iter.next() + val listener = listenerAndMaybeTimer._1 + val maybeTimer = listenerAndMaybeTimer._2 + val maybeTimerContext = if (maybeTimer.isDefined) { + maybeTimer.get.time() + } else { + null + } try { doPostEvent(listener, event) } catch { case NonFatal(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) + } finally { + if (maybeTimerContext != null) { + maybeTimerContext.stop() + } } } } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 46a5cb2cff5a..e5cccf39f945 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -28,7 +28,7 @@ private[spark] object RpcUtils { def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") + Utils.checkHost(driverHost) rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index 95bf3f58bc77..e0f5af5250e7 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -20,11 +20,12 @@ package org.apache.spark.util import org.apache.spark.internal.Logging /** - * The default uncaught exception handler for Executors terminates the whole process, to avoid - * getting into a bad state indefinitely. Since Executors are relatively lightweight, it's better - * to fail fast when things go wrong. + * The default uncaught exception handler for Spark daemons. It terminates the whole process for + * any Errors, and also terminates the process for Exceptions when the exitOnException flag is true. + * + * @param exitOnUncaughtException Whether to exit the process on UncaughtException. 
*/ -private[spark] object SparkUncaughtExceptionHandler +private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException: Boolean = true) extends Thread.UncaughtExceptionHandler with Logging { override def uncaughtException(thread: Thread, exception: Throwable) { @@ -40,7 +41,7 @@ private[spark] object SparkUncaughtExceptionHandler if (!ShutdownHookManager.inShutdown()) { if (exception.isInstanceOf[OutOfMemoryError]) { System.exit(SparkExitCode.OOM) - } else { + } else if (exitOnUncaughtException) { System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) } } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 1aa4456ed01b..81aaf79db0c1 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -206,4 +206,25 @@ private[spark] object ThreadUtils { } } // scalastyle:on awaitresult + + // scalastyle:off awaitready + /** + * Preferred alternative to `Await.ready()`. + * + * @see [[awaitResult]] + */ + @throws(classOf[SparkException]) + def awaitReady[T](awaitable: Awaitable[T], atMost: Duration): awaitable.type = { + try { + // `awaitPermission` is not actually used anywhere so it's safe to pass in null here. + // See SPARK-13747. + val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] + awaitable.ready(atMost)(awaitPermission) + } catch { + // TimeoutException is thrown in the current thread, so not need to warp the exception. + case NonFatal(t) if !t.isInstanceOf[TimeoutException] => + throw new SparkException("Exception thrown in awaitResult: ", t) + } + } + // scalastyle:on awaitready } diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 27922b31949b..6a58ec142dd7 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -55,9 +55,6 @@ private[spark] class UninterruptibleThread( * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning * from `f`. * - * If this method finds that `interrupt` is called before calling `f` and it's not inside another - * `runUninterruptibly`, it will throw `InterruptedException`. - * * Note: this method should be called only in `this` thread. */ def runUninterruptibly[T](f: => T): T = { @@ -73,12 +70,7 @@ private[spark] class UninterruptibleThread( uninterruptibleLock.synchronized { // Clear the interrupted status if it's set. - if (Thread.interrupted() || shouldInterruptThread) { - shouldInterruptThread = false - // Since it's interrupted, we don't need to run `f` which may be a long computation. - // Throw InterruptedException as we don't have a T to return. 
- throw new InterruptedException() - } + shouldInterruptThread = Thread.interrupted() || shouldInterruptThread uninterruptible = true } try { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4d37db96dfc3..836e33c36d9a 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -22,7 +22,7 @@ import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInf import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer -import java.nio.channels.Channels +import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} import java.util.{Locale, Properties, Random, UUID} @@ -60,7 +60,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.util.logging.RollingFileAppender /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -77,6 +76,8 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() + private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler + /** * Define a default value for driver memory here since this value is referenced across the code * base and nearly all files already use Utils.scala @@ -319,41 +320,22 @@ private[spark] object Utils extends Logging { * copying is disabled by default unless explicitly set transferToEnabled as true, * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. */ - def copyStream(in: InputStream, - out: OutputStream, - closeStreams: Boolean = false, - transferToEnabled: Boolean = false): Long = - { - var count = 0L + def copyStream( + in: InputStream, + out: OutputStream, + closeStreams: Boolean = false, + transferToEnabled: Boolean = false): Long = { tryWithSafeFinally { if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream] && transferToEnabled) { // When both streams are File stream, use transferTo to improve copy performance. val inChannel = in.asInstanceOf[FileInputStream].getChannel() val outChannel = out.asInstanceOf[FileOutputStream].getChannel() - val initialPos = outChannel.position() val size = inChannel.size() - - // In case transferTo method transferred less data than we have required. - while (count < size) { - count += inChannel.transferTo(count, size - count, outChannel) - } - - // Check the position after transferTo loop to see if it is in the right position and - // give user information if not. - // Position will not be increased to the expected length after calling transferTo in - // kernel version 2.6.32, this issue can be seen in - // https://bugs.openjdk.java.net/browse/JDK-7052359 - // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). - val finalPos = outChannel.position() - assert(finalPos == initialPos + size, - s""" - |Current position $finalPos do not equal to expected position ${initialPos + size} - |after transferTo, please check your kernel version to see if it is 2.6.32, - |this is a kernel bug which will lead to unexpected behavior when using transferTo. 
- |You can set spark.file.transferTo = false to disable this NIO feature. - """.stripMargin) + copyFileStreamNIO(inChannel, outChannel, 0, size) + size } else { + var count = 0L val buf = new Array[Byte](8192) var n = 0 while (n != -1) { @@ -363,8 +345,8 @@ private[spark] object Utils extends Logging { count += n } } + count } - count } { if (closeStreams) { try { @@ -376,6 +358,37 @@ private[spark] object Utils extends Logging { } } + def copyFileStreamNIO( + input: FileChannel, + output: FileChannel, + startPosition: Long, + bytesToCopy: Long): Unit = { + val initialPos = output.position() + var count = 0L + // In case transferTo method transferred less data than we have required. + while (count < bytesToCopy) { + count += input.transferTo(count + startPosition, bytesToCopy - count, output) + } + assert(count == bytesToCopy, + s"request to copy $bytesToCopy bytes, but actually copied $count bytes.") + + // Check the position after transferTo loop to see if it is in the right position and + // give user information if not. + // Position will not be increased to the expected length after calling transferTo in + // kernel version 2.6.32, this issue can be seen in + // https://bugs.openjdk.java.net/browse/JDK-7052359 + // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). + val finalPos = output.position() + val expectedPos = initialPos + bytesToCopy + assert(finalPos == expectedPos, + s""" + |Current position $finalPos do not equal to expected position $expectedPos + |after transferTo, please check your kernel version to see if it is 2.6.32, + |this is a kernel bug which will lead to unexpected behavior when using transferTo. + |You can set spark.file.transferTo = false to disable this NIO feature. + """.stripMargin) + } + /** * Construct a URI container information used for authentication. * This also sets the default authenticator to properly negotiation the @@ -436,7 +449,7 @@ private[spark] object Utils extends Logging { securityMgr: SecurityManager, hadoopConf: Configuration, timestamp: Long, - useCache: Boolean) { + useCache: Boolean): File = { val fileName = decodeFileNameInURI(new URI(url)) val targetFile = new File(targetDir, fileName) val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true) @@ -485,6 +498,8 @@ private[spark] object Utils extends Logging { if (isWindows) { FileUtil.chmod(targetFile.getAbsolutePath, "u+r") } + + targetFile } /** @@ -624,13 +639,13 @@ private[spark] object Utils extends Logging { * Throws SparkException if the target file already exists and has different contents than * the requested file. */ - private def doFetchFile( + def doFetchFile( url: String, targetDir: File, filename: String, conf: SparkConf, securityMgr: SecurityManager, - hadoopConf: Configuration) { + hadoopConf: Configuration): File = { val targetFile = new File(targetDir, filename) val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) @@ -674,6 +689,8 @@ private[spark] object Utils extends Logging { fetchHcfsFile(path, targetDir, fs, conf, hadoopConf, fileOverwrite, filename = Some(filename)) } + + targetFile } /** @@ -923,6 +940,13 @@ private[spark] object Utils extends Logging { customHostname = Some(hostname) } + /** + * Get the local machine's FQDN. + */ + def localCanonicalHostName(): String = { + customHostname.getOrElse(localIpAddress.getCanonicalHostName) + } + /** * Get the local machine's hostname. 
*/ @@ -937,12 +961,13 @@ private[spark] object Utils extends Logging { customHostname.getOrElse(InetAddresses.toUriString(localIpAddress)) } - def checkHost(host: String, message: String = "") { - assert(host.indexOf(':') == -1, message) + def checkHost(host: String) { + assert(host != null && host.indexOf(':') == -1, s"Expected hostname (not IP) but got $host") } - def checkHostPort(hostPort: String, message: String = "") { - assert(hostPort.indexOf(':') != -1, message) + def checkHostPort(hostPort: String) { + assert(hostPort != null && hostPort.indexOf(':') != -1, + s"Expected host and port but got $hostPort") } // Typically, this will be of order of number of nodes in cluster @@ -990,6 +1015,15 @@ private[spark] object Utils extends Logging { } } + /** + * Lists files recursively. + */ + def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val current = f.listFiles + current ++ current.filter(_.isDirectory).flatMap(recursiveList) + } + /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. @@ -1014,7 +1048,9 @@ private[spark] object Utils extends Logging { ShutdownHookManager.removeShutdownDeleteDir(file) } } finally { - if (!file.delete()) { + if (file.delete()) { + logTrace(s"${file.getAbsolutePath} has been deleted") + } else { // Delete can also fail if the file simply did not exist if (file.exists()) { throw new IOException("Failed to delete: " + file.getAbsolutePath) @@ -1157,16 +1193,17 @@ private[spark] object Utils extends Logging { val second = 1000 val minute = 60 * second val hour = 60 * minute + val locale = Locale.US ms match { case t if t < second => - "%d ms".format(t) + "%d ms".formatLocal(locale, t) case t if t < minute => - "%.1f s".format(t.toFloat / second) + "%.1f s".formatLocal(locale, t.toFloat / second) case t if t < hour => - "%.1f m".format(t.toFloat / minute) + "%.1f m".formatLocal(locale, t.toFloat / minute) case t => - "%.2f h".format(t.toFloat / hour) + "%.2f h".formatLocal(locale, t.toFloat / hour) } } @@ -1251,7 +1288,7 @@ private[spark] object Utils extends Logging { block } catch { case e: ControlThrowable => throw e - case t: Throwable => SparkUncaughtExceptionHandler.uncaughtException(t) + case t: Throwable => sparkUncaughtExceptionHandler.uncaughtException(t) } } @@ -1334,14 +1371,10 @@ private[spark] object Utils extends Logging { try { finallyBlock } catch { - case t: Throwable => - if (originalThrowable != null) { - originalThrowable.addSuppressed(t) - logWarning(s"Suppressing exception in finally: " + t.getMessage, t) - throw originalThrowable - } else { - throw t - } + case t: Throwable if (originalThrowable != null && originalThrowable != t) => + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in finally: ${t.getMessage}", t) + throw originalThrowable } } } @@ -1373,22 +1406,20 @@ private[spark] object Utils extends Logging { catchBlock } catch { case t: Throwable => - originalThrowable.addSuppressed(t) - logWarning(s"Suppressing exception in catch: " + t.getMessage, t) + if (originalThrowable != t) { + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in catch: ${t.getMessage}", t) + } } throw originalThrowable } finally { try { finallyBlock } catch { - case t: Throwable => - if (originalThrowable != null) { - originalThrowable.addSuppressed(t) - logWarning(s"Suppressing exception in finally: " + t.getMessage, t) - throw originalThrowable - } else { - throw t - } + case t: Throwable if (originalThrowable != null && 
originalThrowable != t) => + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in finally: ${t.getMessage}", t) + throw originalThrowable } } } @@ -1424,7 +1455,7 @@ private[spark] object Utils extends Logging { var firstUserFile = "" var firstUserLine = 0 var insideSpark = true - var callStack = new ArrayBuffer[String]() :+ "" + val callStack = new ArrayBuffer[String]() :+ "" Thread.currentThread.getStackTrace().foreach { ste: StackTraceElement => // When running under some profilers, the current stack trace might contain some bogus @@ -2419,7 +2450,7 @@ private[spark] object Utils extends Logging { .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) } - val EMPTY_USER_GROUPS = Set[String]() + val EMPTY_USER_GROUPS = Set.empty[String] // Returns the groups to which the current user belongs. def getCurrentUserGroups(sparkConf: SparkConf, username: String): Set[String] = { @@ -2568,25 +2599,30 @@ private[spark] object Utils extends Logging { * Unions two comma-separated lists of files and filters out empty strings. */ def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = { - var allFiles = Set[String]() + var allFiles = Set.empty[String] leftList.foreach { value => allFiles ++= value.split(",") } rightList.foreach { value => allFiles ++= value.split(",") } allFiles.filter { _.nonEmpty } } /** - * In YARN mode this method returns a union of the jar files pointed by "spark.jars" and the - * "spark.yarn.dist.jars" properties, while in other modes it returns the jar files pointed by - * only the "spark.jars" property. + * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute + * these jars through file server. In the YARN mode, it will return an empty list, since YARN + * has its own mechanism to distribute jars. */ - def getUserJars(conf: SparkConf, isShell: Boolean = false): Seq[String] = { + def getUserJars(conf: SparkConf): Seq[String] = { val sparkJars = conf.getOption("spark.jars") - if (conf.get("spark.master") == "yarn" && isShell) { - val yarnJars = conf.getOption("spark.yarn.dist.jars") - unionFileLists(sparkJars, yarnJars).toSeq - } else { - sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten - } + sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten + } + + /** + * Return the local jar files which will be added to REPL's classpath. These jar files are + * specified by --jars (spark.jars) or --packages, remote jars will be downloaded to local by + * SparkSubmit at first. + */ + def getLocalUserJarsForShell(conf: SparkConf): Seq[String] = { + val localJars = conf.getOption("spark.repl.local.jars") + localJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten } private[spark] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)" @@ -2604,9 +2640,12 @@ private[spark] object Utils extends Logging { * Redact the sensitive information in the given string. 
*/ def redact(conf: SparkConf, text: String): String = { - if (text == null || text.isEmpty || !conf.contains(STRING_REDACTION_PATTERN)) return text - val regex = conf.get(STRING_REDACTION_PATTERN).get - regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT) + if (text == null || text.isEmpty || conf == null || !conf.contains(STRING_REDACTION_PATTERN)) { + text + } else { + val regex = conf.get(STRING_REDACTION_PATTERN).get + regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT) + } } private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { @@ -2645,6 +2684,9 @@ private[spark] object Utils extends Logging { redact(redactionPattern, kvs.toArray) } + def stringToSeq(str: String): Seq[String] = { + str.split(",").map(_.trim()).filter(_.nonEmpty) + } } private[util] object CallerContext extends Logging { diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala index 4d43d8d5cc8d..f5d2fa14e49c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala @@ -126,22 +126,22 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable /** Increase our size to newSize and grow the backing array if needed. */ private def growToSize(newSize: Int): Unit = { - if (newSize < 0) { - throw new UnsupportedOperationException("Can't grow buffer past Int.MaxValue elements") + // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat + // smaller. Be conservative and lower the cap a little. + val arrayMax = Int.MaxValue - 8 + if (newSize < 0 || newSize - 2 > arrayMax) { + throw new UnsupportedOperationException(s"Can't grow buffer past $arrayMax elements") } val capacity = if (otherElements != null) otherElements.length + 2 else 2 if (newSize > capacity) { - var newArrayLen = 8 + var newArrayLen = 8L while (newSize - 2 > newArrayLen) { newArrayLen *= 2 - if (newArrayLen == Int.MinValue) { - // Prevent overflow if we double from 2^30 to 2^31, which will become Int.MinValue. - // Note that we set the new array length to Int.MaxValue - 2 so that our capacity - // calculation above still gives a positive integer. 
- newArrayLen = Int.MaxValue - 2 - } } - val newArray = new Array[T](newArrayLen) + if (newArrayLen > arrayMax) { + newArrayLen = arrayMax + } + val newArray = new Array[T](newArrayLen.toInt) if (otherElements != null) { System.arraycopy(otherElements, 0, newArray, 0, otherElements.length) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8aafda5e45d5..6f5b5bb3652d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -18,6 +18,8 @@ package org.apache.spark.util.collection import java.io._ +import java.nio.channels.{Channels, FileChannel} +import java.nio.file.StandardOpenOption import java.util.Comparator import scala.collection.BufferedIterator @@ -30,7 +32,6 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging -import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerManager} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.CompletionIterator @@ -460,7 +461,7 @@ class ExternalAppendOnlyMap[K, V, C]( ) private var batchIndex = 0 // Which batch we're in - private var fileStream: FileInputStream = null + private var fileChannel: FileChannel = null // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams @@ -477,14 +478,14 @@ class ExternalAppendOnlyMap[K, V, C]( if (batchIndex < batchOffsets.length - 1) { if (deserializeStream != null) { deserializeStream.close() - fileStream.close() + fileChannel.close() deserializeStream = null - fileStream = null + fileChannel = null } val start = batchOffsets(batchIndex) - fileStream = new FileInputStream(file) - fileStream.getChannel.position(start) + fileChannel = FileChannel.open(file.toPath, StandardOpenOption.READ) + fileChannel.position(start) batchIndex += 1 val end = batchOffsets(batchIndex) @@ -492,7 +493,8 @@ class ExternalAppendOnlyMap[K, V, C]( assert(end >= start, "start = " + start + ", end = " + end + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) - val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val bufferedStream = new BufferedInputStream( + ByteStreams.limit(Channels.newInputStream(fileChannel), end - start)) val wrappedStream = serializerManager.wrapStream(blockId, bufferedStream) ser.deserializeStream(wrappedStream) } else { @@ -552,9 +554,9 @@ class ExternalAppendOnlyMap[K, V, C]( ds.close() deserializeStream = null } - if (fileStream != null) { - fileStream.close() - fileStream = null + if (fileChannel != null) { + fileChannel.close() + fileChannel = null } if (file.exists()) { if (!file.delete()) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 176f84fa2a0d..3593cfd50778 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -18,6 +18,8 @@ package org.apache.spark.util.collection import java.io._ +import java.nio.channels.{Channels, FileChannel} 
+import java.nio.file.StandardOpenOption import java.util.Comparator import scala.collection.mutable @@ -492,7 +494,7 @@ private[spark] class ExternalSorter[K, V, C]( // Intermediate file and deserializer streams that read from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - var fileStream: FileInputStream = null + var fileChannel: FileChannel = null var deserializeStream = nextBatchStream() // Also sets fileStream var nextItem: (K, C) = null @@ -505,14 +507,14 @@ private[spark] class ExternalSorter[K, V, C]( if (batchId < batchOffsets.length - 1) { if (deserializeStream != null) { deserializeStream.close() - fileStream.close() + fileChannel.close() deserializeStream = null - fileStream = null + fileChannel = null } val start = batchOffsets(batchId) - fileStream = new FileInputStream(spill.file) - fileStream.getChannel.position(start) + fileChannel = FileChannel.open(spill.file.toPath, StandardOpenOption.READ) + fileChannel.position(start) batchId += 1 val end = batchOffsets(batchId) @@ -520,7 +522,8 @@ private[spark] class ExternalSorter[K, V, C]( assert(end >= start, "start = " + start + ", end = " + end + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) - val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val bufferedStream = new BufferedInputStream( + ByteStreams.limit(Channels.newInputStream(fileChannel), end - start)) val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream) serInstance.deserializeStream(wrappedStream) @@ -610,7 +613,7 @@ private[spark] class ExternalSorter[K, V, C]( batchId = batchOffsets.length // Prevent reading any other batch val ds = deserializeStream deserializeStream = null - fileStream = null + fileChannel = null if (ds != null) { ds.close() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index f5844d5353be..b755e5da5168 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -25,7 +25,7 @@ import org.apache.spark.util.collection.WritablePartitionedPairCollection._ * Append-only buffer of key-value pairs, each with a corresponding partition ID, that keeps track * of its estimated size in bytes. * - * The buffer can support up to `1073741823 (2 ^ 30 - 1)` elements. + * The buffer can support up to 1073741819 elements. */ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) extends WritablePartitionedPairCollection[K, V] with SizeTracker @@ -59,7 +59,7 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_CAPACITY} elements") } val newCapacity = - if (capacity * 2 < 0 || capacity * 2 > MAXIMUM_CAPACITY) { // Overflow + if (capacity * 2 > MAXIMUM_CAPACITY) { // Overflow MAXIMUM_CAPACITY } else { capacity * 2 @@ -96,5 +96,7 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) } private object PartitionedPairBuffer { - val MAXIMUM_CAPACITY = Int.MaxValue / 2 // 2 ^ 30 - 1 + // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat + // smaller. Be conservative and lower the cap a little. 
+ val MAXIMUM_CAPACITY: Int = (Int.MaxValue - 8) / 2 } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 2f905c8af0f6..c28570fb2456 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -24,6 +24,8 @@ import java.nio.channels.WritableByteChannel import com.google.common.primitives.UnsignedBytes import io.netty.buffer.{ByteBuf, Unpooled} +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.storage.StorageUtils @@ -40,6 +42,11 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { require(chunks != null, "chunks must not be null") require(chunks.forall(_.position() == 0), "chunks' positions must be 0") + // Chunk size in bytes + private val bufferWriteChunkSize = + Option(SparkEnv.get).map(_.conf.get(config.BUFFER_WRITE_CHUNK_SIZE)) + .getOrElse(config.BUFFER_WRITE_CHUNK_SIZE.defaultValue.get).toInt + private[this] var disposed: Boolean = false /** @@ -56,7 +63,9 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - while (bytes.remaining > 0) { + while (bytes.remaining() > 0) { + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position + ioSize) channel.write(bytes) } } @@ -66,7 +75,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { * Wrap this buffer to view it as a Netty ByteBuf. */ def toNetty: ByteBuf = { - Unpooled.wrappedBuffer(getChunks(): _*) + Unpooled.wrappedBuffer(chunks.length, getChunks(): _*) } /** diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index fdb1495899bc..2f9ad4c8cc3e 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -94,7 +94,7 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi /** Open the file output stream */ protected def openFile() { - outputStream = new FileOutputStream(file, false) + outputStream = new FileOutputStream(file, true) logDebug(s"Opened file $file") } @@ -125,16 +125,16 @@ private[spark] object FileAppender extends Logging { val validatedParams: Option[(Long, String)] = rollingInterval match { case "daily" => logInfo(s"Rolling executor logs enabled for $file with daily rolling") - Some(24 * 60 * 60 * 1000L, "--yyyy-MM-dd") + Some((24 * 60 * 60 * 1000L, "--yyyy-MM-dd")) case "hourly" => logInfo(s"Rolling executor logs enabled for $file with hourly rolling") - Some(60 * 60 * 1000L, "--yyyy-MM-dd--HH") + Some((60 * 60 * 1000L, "--yyyy-MM-dd--HH")) case "minutely" => logInfo(s"Rolling executor logs enabled for $file with rolling every minute") - Some(60 * 1000L, "--yyyy-MM-dd--HH-mm") + Some((60 * 1000L, "--yyyy-MM-dd--HH-mm")) case IntParam(seconds) => logInfo(s"Rolling executor logs enabled for $file with rolling $seconds seconds") - Some(seconds * 1000L, "--yyyy-MM-dd--HH-mm-ss") + Some((seconds * 1000L, "--yyyy-MM-dd--HH-mm-ss")) case _ => logWarning(s"Illegal interval for rolling executor logs [$rollingInterval], " + s"rolling logs not enabled") diff --git 
a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala index 1be31e88ab68..51feccfb8342 100644 --- a/core/src/main/scala/org/apache/spark/util/taskListeners.scala +++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala @@ -55,14 +55,16 @@ class TaskCompletionListenerException( extends RuntimeException { override def getMessage: String = { - if (errorMessages.size == 1) { - errorMessages.head - } else { - errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") - } + - previousError.map { e => + val listenerErrorMessage = + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + val previousErrorMessage = previousError.map { e => "\n\nPrevious exception in task: " + e.getMessage + "\n" + e.getStackTrace.mkString("\t", "\n\t", "") }.getOrElse("") + listenerErrorMessage + previousErrorMessage } } diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java similarity index 87% rename from core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java rename to core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java index 2c1a34a60759..3440e1aea2f4 100644 --- a/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java @@ -31,11 +31,13 @@ /** * Tests functionality of {@link NioBufferedFileInputStream} */ -public class NioBufferedFileInputStreamSuite { +public abstract class GenericFileInputStreamSuite { private byte[] randomBytes; - private File inputFile; + protected File inputFile; + + protected InputStream inputStream; @Before public void setUp() throws IOException { @@ -52,7 +54,6 @@ public void tearDown() { @Test public void testReadOneByte() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); for (int i = 0; i < randomBytes.length; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); } @@ -60,7 +61,6 @@ public void testReadOneByte() throws IOException { @Test public void testReadMultipleBytes() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); byte[] readBytes = new byte[8 * 1024]; int i = 0; while (i < randomBytes.length) { @@ -74,7 +74,6 @@ public void testReadMultipleBytes() throws IOException { @Test public void testBytesSkipped() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); assertEquals(1024, inputStream.skip(1024)); for (int i = 1024; i < randomBytes.length; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); @@ -83,7 +82,6 @@ public void testBytesSkipped() throws IOException { @Test public void testBytesSkippedAfterRead() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); for (int i = 0; i < 1024; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); } @@ -95,7 +93,6 @@ public void testBytesSkippedAfterRead() throws IOException { @Test public void testNegativeBytesSkippedAfterRead() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); for (int i = 0; i < 1024; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); } @@ -111,7 +108,6 @@ public void testNegativeBytesSkippedAfterRead() throws 
IOException { @Test public void testSkipFromFileChannel() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile, 10); // Since the buffer is smaller than the skipped bytes, this will guarantee // we skip from underlying file channel. assertEquals(1024, inputStream.skip(1024)); @@ -128,7 +124,6 @@ public void testSkipFromFileChannel() throws IOException { @Test public void testBytesSkippedAfterEOF() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); assertEquals(-1, inputStream.read()); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java similarity index 58% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala rename to core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java index 0c7205b3c665..211b33a1a9fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala +++ b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java @@ -14,17 +14,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.io; -package org.apache.spark.sql.catalyst.util +import org.junit.Before; + +import java.io.IOException; /** - * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator - * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix this to - * `Row` in order to work around a spurious IntelliJ compiler error. This cannot be an abstract - * class because that leads to compilation errors under Scala 2.11. + * Tests functionality of {@link NioBufferedFileInputStream} */ -class AbstractScalaRowIterator[T] extends Iterator[T] { - override def hasNext: Boolean = throw new NotImplementedError +public class NioBufferedInputStreamSuite extends GenericFileInputStreamSuite { - override def next(): T = throw new NotImplementedError + @Before + public void setUp() throws IOException { + super.setUp(); + inputStream = new NioBufferedFileInputStream(inputFile); + } } diff --git a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java new file mode 100644 index 000000000000..918ddc4517ec --- /dev/null +++ b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.io; + +import org.junit.Before; + +import java.io.IOException; + +/** + * Tests functionality of {@link NioBufferedFileInputStream} + */ +public class ReadAheadInputStreamSuite extends GenericFileInputStreamSuite { + + @Before + public void setUp() throws IOException { + super.setUp(); + inputStream = new ReadAheadInputStream( + new NioBufferedFileInputStream(inputFile), 8 * 1024, 4 * 1024); + } +} diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 0c7712374085..ac4391e3ef99 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -21,7 +21,6 @@ import java.util.HashMap; import java.util.Map; -import org.junit.Before; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,12 +44,7 @@ public class SparkLauncherSuite { private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); private static final NamedThreadFactory TF = new NamedThreadFactory("SparkLauncherSuite-%d"); - private SparkLauncher launcher; - - @Before - public void configureLauncher() { - launcher = new SparkLauncher().setSparkHome(System.getProperty("spark.test.home")); - } + private final SparkLauncher launcher = new SparkLauncher(); @Test public void testSparkArgumentHandling() throws Exception { @@ -101,60 +95,6 @@ public void testSparkArgumentHandling() throws Exception { assertEquals("python3.5", launcher.builder.conf.get(package$.MODULE$.PYSPARK_PYTHON().key())); } - @Test(expected=IllegalStateException.class) - public void testRedirectTwiceFails() throws Exception { - launcher.setAppResource("fake-resource.jar") - .setMainClass("my.fake.class.Fake") - .redirectError() - .redirectError(ProcessBuilder.Redirect.PIPE) - .launch(); - } - - @Test(expected=IllegalStateException.class) - public void testRedirectToLogWithOthersFails() throws Exception { - launcher.setAppResource("fake-resource.jar") - .setMainClass("my.fake.class.Fake") - .redirectToLog("fakeLog") - .redirectError(ProcessBuilder.Redirect.PIPE) - .launch(); - } - - @Test - public void testRedirectErrorToOutput() throws Exception { - launcher.redirectError(); - assertTrue(launcher.redirectErrorStream); - } - - @Test - public void testRedirectsSimple() throws Exception { - launcher.redirectError(ProcessBuilder.Redirect.PIPE); - assertNotNull(launcher.errorStream); - assertEquals(launcher.errorStream.type(), ProcessBuilder.Redirect.Type.PIPE); - - launcher.redirectOutput(ProcessBuilder.Redirect.PIPE); - assertNotNull(launcher.outputStream); - assertEquals(launcher.outputStream.type(), ProcessBuilder.Redirect.Type.PIPE); - } - - @Test - public void testRedirectLastWins() throws Exception { - launcher.redirectError(ProcessBuilder.Redirect.PIPE) - .redirectError(ProcessBuilder.Redirect.INHERIT); - assertEquals(launcher.errorStream.type(), ProcessBuilder.Redirect.Type.INHERIT); - - launcher.redirectOutput(ProcessBuilder.Redirect.PIPE) - .redirectOutput(ProcessBuilder.Redirect.INHERIT); - assertEquals(launcher.outputStream.type(), ProcessBuilder.Redirect.Type.INHERIT); - } - - @Test - public void testRedirectToLog() throws Exception { - launcher.redirectToLog("fakeLogger"); - assertTrue(launcher.redirectToLog); - assertTrue(launcher.builder.getEffectiveConfig() - .containsKey(SparkLauncher.CHILD_PROCESS_LOGGER_NAME)); - } - @Test public void testChildProcLauncher() throws 
Exception { // This test is failed on Windows due to the failure of initiating executors @@ -175,11 +115,11 @@ public void testChildProcLauncher() throws Exception { .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) + .redirectError() .addAppArgs("proc"); final Process app = launcher.launch(); - new OutputRedirector(app.getInputStream(), TF); - new OutputRedirector(app.getErrorStream(), TF); + new OutputRedirector(app.getInputStream(), getClass().getName() + ".child", TF); assertEquals(0, app.waitFor()); } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index f53bc0b02bbf..46b0516e3614 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -54,6 +54,7 @@ public void encodePageNumberAndOffsetOffHeap() { final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); Assert.assertEquals(null, manager.getPage(encodedAddress)); Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); + manager.freePage(dataPage, c); } @Test diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 771d39016c18..5330a688e63e 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -154,7 +154,7 @@ private UnsafeExternalSorter newSorter() throws IOException { blockManager, serializerManager, taskContext, - recordComparator, + () -> recordComparator, prefixComparator, /* initialSize */ 1024, pageSizeBytes, @@ -395,7 +395,7 @@ public void forcedSpillingWithoutComparator() throws Exception { sorter.spill(); } } - UnsafeSorterIterator iter = sorter.getIterator(); + UnsafeSorterIterator iter = sorter.getIterator(0); for (int i = 0; i < n; i++) { iter.hasNext(); iter.loadNext(); @@ -405,6 +405,31 @@ public void forcedSpillingWithoutComparator() throws Exception { assertSpillFilesWereCleanedUp(); } + @Test + public void testDiskSpilledBytes() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + long[] record = new long[100]; + int recordSize = record.length * 8; + int n = (int) pageSizeBytes / recordSize * 3; + for (int i = 0; i < n; i++) { + record[0] = (long) i; + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false); + } + // We will have at-least 2 memory pages allocated because of rounding happening due to + // integer division of pageSizeBytes and recordSize. 
+ assertTrue(sorter.getNumberOfAllocatedPages() >= 2); + assertTrue(taskContext.taskMetrics().diskBytesSpilled() == 0); + UnsafeExternalSorter.SpillableIterator iter = + (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator(); + assertTrue(iter.spill() > 0); + assertTrue(taskContext.taskMetrics().diskBytesSpilled() > 0); + assertEquals(0, iter.spill()); + // Even if we did not spill second time, the disk spilled bytes should still be non-zero + assertTrue(taskContext.taskMetrics().diskBytesSpilled() > 0); + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + @Test public void testPeakMemoryUsed() throws Exception { final long recordLengthBytes = 8; @@ -415,7 +440,7 @@ public void testPeakMemoryUsed() throws Exception { blockManager, serializerManager, taskContext, - recordComparator, + () -> recordComparator, prefixComparator, 1024, pageSizeBytes, @@ -454,5 +479,37 @@ public void testPeakMemoryUsed() throws Exception { } } + @Test + public void testGetIterator() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + for (int i = 0; i < 100; i++) { + insertNumber(sorter, i); + } + verifyIntIterator(sorter.getIterator(0), 0, 100); + verifyIntIterator(sorter.getIterator(79), 79, 100); + + sorter.spill(); + for (int i = 100; i < 200; i++) { + insertNumber(sorter, i); + } + sorter.spill(); + verifyIntIterator(sorter.getIterator(79), 79, 200); + + for (int i = 200; i < 300; i++) { + insertNumber(sorter, i); + } + verifyIntIterator(sorter.getIterator(79), 79, 300); + verifyIntIterator(sorter.getIterator(139), 139, 300); + verifyIntIterator(sorter.getIterator(279), 279, 300); + } + + private void verifyIntIterator(UnsafeSorterIterator iter, int start, int end) + throws IOException { + for (int i = start; i < end; i++) { + assert (iter.hasNext()); + iter.loadNext(); + assert (Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()) == i); + } + } } diff --git a/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java b/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java new file mode 100644 index 000000000000..7e9cc70d8651 --- /dev/null +++ b/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark; + +import java.io.*; + +import scala.collection.immutable.List; +import scala.collection.immutable.List$; +import scala.collection.immutable.Map; +import scala.collection.immutable.Map$; + +import org.junit.Test; + +import org.apache.spark.api.java.*; +import org.apache.spark.*; + +/** + * Java apps can uses both Java-friendly JavaSparkContext and Scala SparkContext. 
+ */ +public class JavaSparkContextSuite implements Serializable { + + @Test + public void javaSparkContext() { + String[] jars = new String[] {}; + java.util.Map<String, String> environment = new java.util.HashMap<>(); + + new JavaSparkContext(new SparkConf().setMaster("local").setAppName("name")).stop(); + new JavaSparkContext("local", "name", new SparkConf()).stop(); + new JavaSparkContext("local", "name").stop(); + new JavaSparkContext("local", "name", "sparkHome", "jarFile").stop(); + new JavaSparkContext("local", "name", "sparkHome", jars).stop(); + new JavaSparkContext("local", "name", "sparkHome", jars, environment).stop(); + } + + @Test + public void scalaSparkContext() { + List<String> jars = List$.MODULE$.empty(); + Map<String, String> environment = Map$.MODULE$.empty(); + + new SparkContext(new SparkConf().setMaster("local").setAppName("name")).stop(); + new SparkContext("local", "name", new SparkConf()).stop(); + new SparkContext("local", "name").stop(); + new SparkContext("local", "name", "sparkHome").stop(); + new SparkContext("local", "name", "sparkHome", jars).stop(); + new SparkContext("local", "name", "sparkHome", jars, environment).stop(); + } +} diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index 10902ab5a832..f2c3ec5da889 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -8,6 +8,7 @@ "duration" : 10671, "sparkUser" : "jose", "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", "endTimeEpoch" : 1479335620587, "startTimeEpoch" : 1479335609916, "lastUpdatedEpoch" : 0 @@ -22,6 +23,7 @@ "duration" : 101795, "sparkUser" : "jose", "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", "endTimeEpoch" : 1479252138874, "startTimeEpoch" : 1479252037079, "lastUpdatedEpoch" : 0 @@ -36,6 +38,7 @@ "duration" : 10505, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917391398, "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 @@ -51,6 +54,7 @@ "duration" : 57, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380950, "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 @@ -62,6 +66,7 @@ "duration" : 10, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380890, "startTimeEpoch" : 1430917380880, "lastUpdatedEpoch" : 0 @@ -77,6 +82,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1426633945177, "startTimeEpoch" : 1426633910242, "lastUpdatedEpoch" : 0 @@ -88,6 +94,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1426533945177, "startTimeEpoch" : 1426533910242, "lastUpdatedEpoch" : 0 @@ -102,6 +109,7 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1425081766912, "startTimeEpoch" : 1425081758277, "lastUpdatedEpoch" : 0 @@ -116,6 +124,7 @@ "duration" : 9011, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1422981788731, "startTimeEpoch" : 1422981779720, "lastUpdatedEpoch" : 0 @@ -130,6 +139,7 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1422981766912,
"startTimeEpoch" : 1422981758277, "lastUpdatedEpoch" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index 10902ab5a832..c925c1dd8a4d 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -8,6 +8,7 @@ "duration" : 10671, "sparkUser" : "jose", "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", "endTimeEpoch" : 1479335620587, "startTimeEpoch" : 1479335609916, "lastUpdatedEpoch" : 0 @@ -22,6 +23,7 @@ "duration" : 101795, "sparkUser" : "jose", "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", "endTimeEpoch" : 1479252138874, "startTimeEpoch" : 1479252037079, "lastUpdatedEpoch" : 0 @@ -36,6 +38,7 @@ "duration" : 10505, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917391398, "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 @@ -51,6 +54,7 @@ "duration" : 57, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380950, "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 @@ -62,6 +66,7 @@ "duration" : 10, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380890, "startTimeEpoch" : 1430917380880, "lastUpdatedEpoch" : 0 @@ -77,6 +82,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1426633945177, "startTimeEpoch" : 1426633910242, "lastUpdatedEpoch" : 0 @@ -88,6 +94,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1426533945177, "startTimeEpoch" : 1426533910242, "lastUpdatedEpoch" : 0 @@ -102,6 +109,8 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", + "appSparkVersion" : "", "endTimeEpoch" : 1425081766912, "startTimeEpoch" : 1425081758277, "lastUpdatedEpoch" : 0 @@ -116,6 +125,7 @@ "duration" : 9011, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1422981788731, "startTimeEpoch" : 1422981779720, "lastUpdatedEpoch" : 0 @@ -130,6 +140,7 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1422981766912, "startTimeEpoch" : 1422981758277, "lastUpdatedEpoch" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json index 8820c717f85d..cc0b2b0022bd 100644 --- a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -8,6 +8,7 @@ "duration" : 10671, "sparkUser" : "jose", "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", "endTimeEpoch" : 1479335620587, "startTimeEpoch" : 1479335609916, "lastUpdatedEpoch" : 0 @@ -22,6 +23,7 @@ "duration" : 101795, "sparkUser" : "jose", "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", "endTimeEpoch" : 1479252138874, "startTimeEpoch" : 1479252037079, "lastUpdatedEpoch" : 0 @@ -36,6 +38,7 @@ "duration" : 10505, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917391398, 
"startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json index c3fe4db222ae..fa12413eeb0e 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json @@ -8,6 +8,7 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1422981766912, "startTimeEpoch" : 1422981758277, "lastUpdatedEpoch" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json index 8281fa75aa0d..a0d4a0d1c455 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json @@ -8,6 +8,7 @@ "duration" : 9011, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1422981788731, "startTimeEpoch" : 1422981779720, "lastUpdatedEpoch" : 0 @@ -22,6 +23,7 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1422981766912, "startTimeEpoch" : 1422981758277, "lastUpdatedEpoch" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/maxEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxEndDate_app_list_json_expectation.json index 1842f1888b78..dfa90010c6ca 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxEndDate_app_list_json_expectation.json @@ -9,6 +9,7 @@ "duration" : 57, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, "endTimeEpoch" : 1430917380950 @@ -20,6 +21,7 @@ "duration" : 10, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, "endTimeEpoch" : 1430917380890 @@ -35,6 +37,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, "endTimeEpoch" : 1426633945177 @@ -46,6 +49,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, "endTimeEpoch" : 1426533945177 @@ -60,6 +64,7 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, "endTimeEpoch" : 1425081766912 @@ -74,6 +79,7 @@ "duration" : 9011, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981779720, "endTimeEpoch" : 1422981788731 @@ -88,6 +94,7 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981758277, "endTimeEpoch" : 1422981766912 diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_and_maxEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_and_maxEndDate_app_list_json_expectation.json 
index 24f9f21ec650..3ebe60e2cd03 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_and_maxEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_and_maxEndDate_app_list_json_expectation.json @@ -9,6 +9,7 @@ "duration" : 57, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, "endTimeEpoch" : 1430917380950 @@ -20,6 +21,7 @@ "duration" : 10, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, "endTimeEpoch" : 1430917380890 @@ -35,6 +37,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, "endTimeEpoch" : 1426633945177 @@ -46,6 +49,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, "endTimeEpoch" : 1426533945177 diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 1930281f1a3e..5af50abd8533 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -8,6 +8,7 @@ "duration" : 10671, "sparkUser" : "jose", "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", "endTimeEpoch" : 1479335620587, "startTimeEpoch" : 1479335609916, "lastUpdatedEpoch" : 0 @@ -22,6 +23,7 @@ "duration" : 101795, "sparkUser" : "jose", "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", "endTimeEpoch" : 1479252138874, "startTimeEpoch" : 1479252037079, "lastUpdatedEpoch" : 0 @@ -36,6 +38,7 @@ "duration" : 10505, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917391398, "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 @@ -51,6 +54,7 @@ "duration" : 57, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380950, "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 @@ -62,6 +66,7 @@ "duration" : 10, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380890, "startTimeEpoch" : 1430917380880, "lastUpdatedEpoch" : 0 @@ -77,6 +82,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1426633945177, "startTimeEpoch" : 1426633910242, "lastUpdatedEpoch" : 0 @@ -88,6 +94,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1426533945177, "startTimeEpoch" : 1426533910242, "lastUpdatedEpoch" : 0 @@ -102,6 +109,7 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1425081766912, "startTimeEpoch" : 1425081758277, "lastUpdatedEpoch" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_and_maxEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_and_maxEndDate_app_list_json_expectation.json index 3745e8a09a98..74a7b40a5927 100644 --- a/core/src/test/resources/HistoryServerExpectations/minEndDate_and_maxEndDate_app_list_json_expectation.json +++ 
b/core/src/test/resources/HistoryServerExpectations/minEndDate_and_maxEndDate_app_list_json_expectation.json @@ -9,6 +9,7 @@ "duration" : 57, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, "endTimeEpoch" : 1430917380950 @@ -20,6 +21,7 @@ "duration" : 10, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, "endTimeEpoch" : 1430917380890 @@ -35,6 +37,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, "endTimeEpoch" : 1426633945177 @@ -46,6 +49,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, "endTimeEpoch" : 1426533945177 diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json index 05233db441ed..7f896c74b5be 100644 --- a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json @@ -8,6 +8,7 @@ "duration" : 10671, "sparkUser" : "jose", "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", "startTimeEpoch" : 1479335609916, "lastUpdatedEpoch" : 0, "endTimeEpoch" : 1479335620587 @@ -22,6 +23,7 @@ "duration" : 101795, "sparkUser" : "jose", "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", "startTimeEpoch" : 1479252037079, "lastUpdatedEpoch" : 0, "endTimeEpoch" : 1479252138874 @@ -36,6 +38,7 @@ "duration" : 10505, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, "endTimeEpoch" : 1430917391398 @@ -51,6 +54,7 @@ "duration" : 57, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, "endTimeEpoch" : 1430917380950 @@ -62,6 +66,7 @@ "duration" : 10, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, "endTimeEpoch" : 1430917380890 diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json index e8ed96dc85f0..24ec6a163fc2 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json @@ -8,6 +8,7 @@ "duration" : 9011, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1422981788731, "startTimeEpoch" : 1422981779720, "lastUpdatedEpoch" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json index 88c601512d79..94b6d6dba76e 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json @@ -9,6 +9,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" :
"", "endTimeEpoch" : 1426633945177, "startTimeEpoch" : 1426633910242, "lastUpdatedEpoch" : 0 @@ -20,6 +21,7 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, + "appSparkVersion" : "", "endTimeEpoch" : 1426533945177, "startTimeEpoch" : 1426533910242, "lastUpdatedEpoch" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index c2f450ba87c6..6fb40f6f1713 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -60,6 +60,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -105,6 +106,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -150,6 +152,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -195,6 +198,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -240,6 +244,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -285,6 +290,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -330,6 +336,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -375,6 +382,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 506859ae545b..f5a89a210764 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -60,6 +60,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -105,6 +106,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -150,6 +152,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -195,6 +198,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -240,6 +244,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -285,6 +290,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -330,6 +336,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, 
"recordsRead" : 0 }, @@ -375,6 +382,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index f4cec68fbfdf..9b401b414f8d 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ 
"localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index 496a21c328da..2ebee66a6d7c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -38,6 +38,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -87,6 +88,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -136,6 +138,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -185,6 +188,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -234,6 +238,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -283,6 +288,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -332,6 +338,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -381,6 +388,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 4328dc753c5d..965a31a4104c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -38,6 +38,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -87,6 +88,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -136,6 +138,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -185,6 +188,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -234,6 +238,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -283,6 +288,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, 
"remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -332,6 +338,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -381,6 +388,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index 8c571430f3a1..31132e156937 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ 
"localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -913,6 +933,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -957,6 +978,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1001,6 +1023,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1045,6 +1068,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1089,6 +1113,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1133,6 +1158,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1177,6 +1203,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1221,6 +1248,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1265,6 +1293,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1309,6 +1338,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1353,6 +1383,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1397,6 +1428,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1441,6 +1473,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1485,6 +1518,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1529,6 +1563,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1573,6 +1608,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1617,6 +1653,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1661,6 +1698,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + 
"remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1705,6 +1743,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1749,6 +1788,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1793,6 +1833,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1837,6 +1878,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1881,6 +1923,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1925,6 +1968,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1969,6 +2013,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2013,6 +2058,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2057,6 +2103,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2101,6 +2148,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2145,6 +2193,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2189,6 +2238,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index 0bd614bdc756..6af1cfbeb8f7 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, 
"remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index 0bd614bdc756..6af1cfbeb8f7 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, 
"recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index b58f1a51ba48..c26daf4b8d7b 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" 
: 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index 0ed609d5b7f9..f8e27703c0de 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -24,6 +24,7 @@ "localBlocksFetched" : [ 0.0, 0.0, 0.0 ], "fetchWaitTime" : [ 0.0, 0.0, 0.0 ], "remoteBytesRead" : [ 0.0, 0.0, 0.0 ], + "remoteBytesReadToDisk" : [ 0.0, 0.0, 0.0 ], "totalBlocksFetched" : [ 0.0, 0.0, 0.0 ] }, "shuffleWriteMetrics" : { diff --git 
a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index 6d230ac65377..a28bda16a956 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -24,6 +24,7 @@ "localBlocksFetched" : [ 100.0, 100.0, 100.0, 100.0, 100.0 ], "fetchWaitTime" : [ 0.0, 0.0, 0.0, 1.0, 1.0 ], "remoteBytesRead" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "remoteBytesReadToDisk" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "totalBlocksFetched" : [ 100.0, 100.0, 100.0, 100.0, 100.0 ] }, "shuffleWriteMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index aea0f5413d8b..ede3eaed1d1d 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -24,6 +24,7 @@ "localBlocksFetched" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "fetchWaitTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "remoteBytesRead" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "remoteBytesReadToDisk" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "totalBlocksFetched" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ] }, "shuffleWriteMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index a449926ee7dc..44b5f66efe33 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -69,6 +69,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -119,6 +120,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -169,6 +171,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -219,6 +222,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -269,6 +273,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -319,6 +324,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -369,6 +375,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -419,6 +426,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/fairscheduler-with-valid-data.xml b/core/src/test/resources/fairscheduler-with-valid-data.xml new file mode 100644 index 000000000000..3d882331835c --- /dev/null +++ 
b/core/src/test/resources/fairscheduler-with-valid-data.xml @@ -0,0 +1,35 @@ + + + + + + 3 + 1 + FIFO + + + 4 + 2 + FAIR + + + 2 + 3 + FAIR + + \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index ddbcb2d19dcb..3990ee1ec326 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -210,7 +210,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(ref.get.isEmpty) // Getting a garbage collected accum should throw error - intercept[IllegalAccessError] { + intercept[IllegalStateException] { AccumulatorContext.get(accId) } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index ee70a3399efe..48408ccc8f81 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -114,7 +114,7 @@ trait RDDCheckpointTester { self: SparkFunSuite => * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, * the generated RDD will remember the partitions and therefore potentially the whole lineage. * This function should be called only those RDD whose partitions refer to parent RDD's - * partitions (i.e., do not call it on simple RDD like MappedRDD). + * partitions (i.e., do not call it on simple RDDs). * * @param op an operation to run on the RDD * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints @@ -388,7 +388,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS // the parent RDD has been checkpointed and parent partitions have been changed. // Note that this test is very specific to the current implementation of CartesianRDD. val ones = sc.makeRDD(1 to 100, 10).map(x => x) - checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD + checkpoint(ones, reliableCheckpoint) val cartesian = new CartesianRDD(sc, ones, ones) val splitBeforeCheckpoint = serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) @@ -411,7 +411,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS // Note that this test is very specific to the current implementation of // CoalescedRDDPartitions. 
val ones = sc.makeRDD(1 to 100, 10).map(x => x) - checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD + checkpoint(ones, reliableCheckpoint) val coalesced = new CoalescedRDD(ones, 2) val splitBeforeCheckpoint = serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 84f7f1fc8eb0..bea67b71a5a1 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark -import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers +import org.scalatest.concurrent.TimeLimits._ import org.scalatest.time.{Millis, Span} import org.apache.spark.security.EncryptionFunSuite diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 454b7e607a51..be80d278fcea 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark import java.io.File -import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.TimeLimits import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.Utils -class DriverSuite extends SparkFunSuite with Timeouts { +class DriverSuite extends SparkFunSuite with TimeLimits { ignore("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 4ea42fc7d5c2..a91e09b7cb69 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -49,6 +49,11 @@ class ExecutorAllocationManagerSuite contexts.foreach(_.stop()) } + private def post(bus: LiveListenerBus, event: SparkListenerEvent): Unit = { + bus.post(event) + bus.waitUntilEmpty(1000) + } + test("verify min/max executors") { val conf = new SparkConf() .setMaster("myDummyLocalExternalClusterManager") @@ -95,7 +100,7 @@ class ExecutorAllocationManagerSuite test("add executors") { sc = createSparkContext(1, 10, 1) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Keep adding until the limit is reached assert(numExecutorsTarget(manager) === 1) @@ -140,7 +145,7 @@ class ExecutorAllocationManagerSuite test("add executors capped by num pending tasks") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 5))) // Verify that we're capped at number of tasks in the stage assert(numExecutorsTarget(manager) === 0) @@ -156,10 +161,10 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that running a task doesn't affect the target - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 
post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 3))) + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) - sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) assert(numExecutorsTarget(manager) === 5) assert(addExecutors(manager) === 1) assert(numExecutorsTarget(manager) === 6) @@ -172,9 +177,9 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that re-running a task doesn't blow things up - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 3))) - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(2, 3))) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) assert(addExecutors(manager) === 1) assert(numExecutorsTarget(manager) === 9) assert(numExecutorsToAdd(manager) === 2) @@ -183,15 +188,49 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that running a task once we're at our limit doesn't blow things up - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) assert(addExecutors(manager) === 0) assert(numExecutorsTarget(manager) === 10) } + test("add executors when speculative tasks added") { + sc = createSparkContext(0, 10, 0) + val manager = sc.executorAllocationManager.get + + // Verify that we're capped at number of tasks including the speculative ones in the stage + post(sc.listenerBus, SparkListenerSpeculativeTaskSubmitted(1)) + assert(numExecutorsTarget(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 1) + post(sc.listenerBus, SparkListenerSpeculativeTaskSubmitted(1)) + post(sc.listenerBus, SparkListenerSpeculativeTaskSubmitted(1)) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 2))) + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsTarget(manager) === 3) + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 2) + assert(numExecutorsTarget(manager) === 5) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that running a task doesn't affect the target + post(sc.listenerBus, SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + assert(numExecutorsTarget(manager) === 5) + assert(addExecutors(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that running a speculative task doesn't affect the target + post(sc.listenerBus, SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-2", true))) + assert(numExecutorsTarget(manager) === 5) + assert(addExecutors(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + } + test("cancel pending executors when no longer needed") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(2, 5))) 
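For readability, the `post(sc.listenerBus, ...)` calls that replace `sc.listenerBus.postToAll(...)` throughout this suite all go through the small helper added near the top of ExecutorAllocationManagerSuite. Restated here as a sketch (the usage lines are illustrative), it posts the event and then blocks until the bus has drained, so the assertions that follow observe the allocation manager's updated state rather than racing the listener thread:

```scala
private def post(bus: LiveListenerBus, event: SparkListenerEvent): Unit = {
  bus.post(event)
  bus.waitUntilEmpty(1000)   // wait (ms) for listeners to process the event
}

// Typical usage, mirroring the tests above:
post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1000)))
assert(numExecutorsTarget(manager) === 1)
```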
assert(numExecutorsTarget(manager) === 0) assert(numExecutorsToAdd(manager) === 1) @@ -202,15 +241,15 @@ class ExecutorAllocationManagerSuite assert(numExecutorsTarget(manager) === 3) val task1Info = createTaskInfo(0, 0, "executor-1") - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task1Info)) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, task1Info)) assert(numExecutorsToAdd(manager) === 4) assert(addExecutors(manager) === 2) val task2Info = createTaskInfo(1, 0, "executor-1") - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, task2Info)) + post(sc.listenerBus, SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) + post(sc.listenerBus, SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) assert(adjustRequestedExecutors(manager) === -1) } @@ -314,10 +353,50 @@ class ExecutorAllocationManagerSuite assert(executorsPendingToRemove(manager).isEmpty) } + test ("Removing with various numExecutorsTarget condition") { + sc = createSparkContext(5, 12, 5) + val manager = sc.executorAllocationManager.get + + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 8))) + + // Remove when numExecutorsTarget is the same as the current number of executors + assert(addExecutors(manager) === 1) + assert(addExecutors(manager) === 2) + (1 to 8).map { i => createTaskInfo(i, i, s"$i") }.foreach { + info => post(sc.listenerBus, SparkListenerTaskStart(0, 0, info)) } + assert(executorIds(manager).size === 8) + assert(numExecutorsTarget(manager) === 8) + assert(maxNumExecutorsNeeded(manager) == 8) + assert(!removeExecutor(manager, "1")) // won't work since numExecutorsTarget == numExecutors + + // Remove executors when numExecutorsTarget is lower than current number of executors + (1 to 3).map { i => createTaskInfo(i, i, s"$i") }.foreach { info => + post(sc.listenerBus, SparkListenerTaskEnd(0, 0, null, Success, info, null)) + } + adjustRequestedExecutors(manager) + assert(executorIds(manager).size === 8) + assert(numExecutorsTarget(manager) === 5) + assert(maxNumExecutorsNeeded(manager) == 5) + assert(removeExecutor(manager, "1")) + assert(removeExecutors(manager, Seq("2", "3"))=== Seq("2", "3")) + onExecutorRemoved(manager, "1") + onExecutorRemoved(manager, "2") + onExecutorRemoved(manager, "3") + + // numExecutorsTarget is lower than minNumExecutors + post(sc.listenerBus, + SparkListenerTaskEnd(0, 0, null, Success, createTaskInfo(4, 4, "4"), null)) + assert(executorIds(manager).size === 5) + assert(numExecutorsTarget(manager) === 5) + assert(maxNumExecutorsNeeded(manager) == 4) + assert(!removeExecutor(manager, "4")) // lower limit + assert(addExecutors(manager) === 0) // upper limit + } + test ("interleaving add and remove") { - sc = createSparkContext(5, 10, 5) + sc = createSparkContext(5, 12, 5) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Add a few executors assert(addExecutors(manager) === 1) @@ -331,52 +410,59 @@ class ExecutorAllocationManagerSuite onExecutorAdded(manager, "7") onExecutorAdded(manager, "8") assert(executorIds(manager).size === 8) + assert(numExecutorsTarget(manager) === 8) - // Remove until limit - assert(removeExecutor(manager, "1")) - 
assert(removeExecutors(manager, Seq("2", "3")) === Seq("2", "3")) - assert(!removeExecutor(manager, "4")) // lower limit reached - assert(!removeExecutor(manager, "5")) - onExecutorRemoved(manager, "1") - onExecutorRemoved(manager, "2") - onExecutorRemoved(manager, "3") - assert(executorIds(manager).size === 5) - // Add until limit - assert(addExecutors(manager) === 2) // upper limit reached - assert(addExecutors(manager) === 0) - assert(!removeExecutor(manager, "4")) // still at lower limit - assert((manager, Seq("5")) !== Seq("5")) + // Remove when numTargetExecutors is equal to the current number of executors + assert(!removeExecutor(manager, "1")) + assert(removeExecutors(manager, Seq("2", "3")) !== Seq("2", "3")) + + // Remove until limit onExecutorAdded(manager, "9") onExecutorAdded(manager, "10") onExecutorAdded(manager, "11") onExecutorAdded(manager, "12") - onExecutorAdded(manager, "13") - assert(executorIds(manager).size === 10) + assert(executorIds(manager).size === 12) + assert(numExecutorsTarget(manager) === 8) - // Remove succeeds again, now that we are no longer at the lower limit - assert(removeExecutors(manager, Seq("4", "5", "6")) === Seq("4", "5", "6")) - assert(removeExecutor(manager, "7")) - assert(executorIds(manager).size === 10) - assert(addExecutors(manager) === 0) + assert(removeExecutor(manager, "1")) + assert(removeExecutors(manager, Seq("2", "3", "4")) === Seq("2", "3", "4")) + assert(!removeExecutor(manager, "5")) // lower limit reached + assert(!removeExecutor(manager, "6")) + onExecutorRemoved(manager, "1") + onExecutorRemoved(manager, "2") + onExecutorRemoved(manager, "3") onExecutorRemoved(manager, "4") - onExecutorRemoved(manager, "5") assert(executorIds(manager).size === 8) - // Number of executors pending restarts at 1 - assert(numExecutorsToAdd(manager) === 1) - assert(addExecutors(manager) === 0) - assert(executorIds(manager).size === 8) - onExecutorRemoved(manager, "6") - onExecutorRemoved(manager, "7") + // Add until limit + assert(!removeExecutor(manager, "7")) // still at lower limit + assert((manager, Seq("8")) !== Seq("8")) + onExecutorAdded(manager, "13") onExecutorAdded(manager, "14") onExecutorAdded(manager, "15") - assert(executorIds(manager).size === 8) - assert(addExecutors(manager) === 0) // still at upper limit onExecutorAdded(manager, "16") + assert(executorIds(manager).size === 12) + + // Remove succeeds again, now that we are no longer at the lower limit + assert(removeExecutors(manager, Seq("5", "6", "7")) === Seq("5", "6", "7")) + assert(removeExecutor(manager, "8")) + assert(executorIds(manager).size === 12) + onExecutorRemoved(manager, "5") + onExecutorRemoved(manager, "6") + assert(executorIds(manager).size === 10) + assert(numExecutorsToAdd(manager) === 4) + onExecutorRemoved(manager, "9") + onExecutorRemoved(manager, "10") + assert(addExecutors(manager) === 4) // at upper limit onExecutorAdded(manager, "17") + onExecutorAdded(manager, "18") assert(executorIds(manager).size === 10) - assert(numExecutorsTarget(manager) === 10) + assert(addExecutors(manager) === 0) // still at upper limit + onExecutorAdded(manager, "19") + onExecutorAdded(manager, "20") + assert(executorIds(manager).size === 12) + assert(numExecutorsTarget(manager) === 12) } test("starting/canceling add timer") { @@ -489,7 +575,7 @@ class ExecutorAllocationManagerSuite val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + 
post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Scheduler queue backlogged onSchedulerBacklogged(manager) @@ -602,26 +688,26 @@ class ExecutorAllocationManagerSuite // Starting a stage should start the add timer val numTasks = 10 - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, numTasks))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, numTasks))) assert(addTime(manager) !== NOT_SET) // Starting a subset of the tasks should not cancel the add timer val taskInfos = (0 to numTasks - 1).map { i => createTaskInfo(i, i, "executor-1") } - taskInfos.tail.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, info)) } + taskInfos.tail.foreach { info => post(sc.listenerBus, SparkListenerTaskStart(0, 0, info)) } assert(addTime(manager) !== NOT_SET) // Starting all remaining tasks should cancel the add timer - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfos.head)) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, taskInfos.head)) assert(addTime(manager) === NOT_SET) // Start two different stages // The add timer should be canceled only if all tasks in both stages start running - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, numTasks))) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, numTasks))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, numTasks))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(2, numTasks))) assert(addTime(manager) !== NOT_SET) - taskInfos.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, info)) } + taskInfos.foreach { info => post(sc.listenerBus, SparkListenerTaskStart(1, 0, info)) } assert(addTime(manager) !== NOT_SET) - taskInfos.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, info)) } + taskInfos.foreach { info => post(sc.listenerBus, SparkListenerTaskStart(2, 0, info)) } assert(addTime(manager) === NOT_SET) } @@ -635,22 +721,22 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager).size === 5) // Starting a task cancel the remove timer for that executor - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(1, 1, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(2, 2, "executor-2"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(1, 1, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(2, 2, "executor-2"))) assert(removeTimes(manager).size === 3) assert(!removeTimes(manager).contains("executor-1")) assert(!removeTimes(manager).contains("executor-2")) // Finishing all tasks running on an executor should start the remove timer for that executor - sc.listenerBus.postToAll(SparkListenerTaskEnd( + post(sc.listenerBus, SparkListenerTaskEnd( 0, 0, "task-type", Success, createTaskInfo(0, 0, "executor-1"), new TaskMetrics)) - sc.listenerBus.postToAll(SparkListenerTaskEnd( + post(sc.listenerBus, SparkListenerTaskEnd( 0, 0, "task-type", Success, createTaskInfo(2, 2, "executor-2"), new TaskMetrics)) assert(removeTimes(manager).size === 4) assert(!removeTimes(manager).contains("executor-1")) // executor-1 has not finished yet assert(removeTimes(manager).contains("executor-2")) - sc.listenerBus.postToAll(SparkListenerTaskEnd( + 
post(sc.listenerBus, SparkListenerTaskEnd( 0, 0, "task-type", Success, createTaskInfo(1, 1, "executor-1"), new TaskMetrics)) assert(removeTimes(manager).size === 5) assert(removeTimes(manager).contains("executor-1")) // executor-1 has now finished @@ -663,13 +749,13 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager).isEmpty) // New executors have registered - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(removeTimes(manager).contains("executor-1")) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-2", new ExecutorInfo("host2", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) @@ -677,14 +763,14 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager).contains("executor-2")) // Existing executors have disconnected - sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-1", "")) + post(sc.listenerBus, SparkListenerExecutorRemoved(0L, "executor-1", "")) assert(executorIds(manager).size === 1) assert(!executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(!removeTimes(manager).contains("executor-1")) // Unknown executor has disconnected - sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-3", "")) + post(sc.listenerBus, SparkListenerExecutorRemoved(0L, "executor-3", "")) assert(executorIds(manager).size === 1) assert(removeTimes(manager).size === 1) } @@ -695,8 +781,8 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) @@ -708,15 +794,15 @@ class ExecutorAllocationManagerSuite val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 0) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-2", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) @@ -729,7 +815,7 @@ class ExecutorAllocationManagerSuite sc = createSparkContext(0, 100000, 0) val manager = sc.executorAllocationManager.get val stage1 = createStageInfo(0, 1000) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(stage1)) + post(sc.listenerBus, SparkListenerStageSubmitted(stage1)) 
assert(addExecutors(manager) === 1) assert(addExecutors(manager) === 2) @@ -740,12 +826,12 @@ class ExecutorAllocationManagerSuite onExecutorAdded(manager, s"executor-$i") } assert(executorIds(manager).size === 15) - sc.listenerBus.postToAll(SparkListenerStageCompleted(stage1)) + post(sc.listenerBus, SparkListenerStageCompleted(stage1)) adjustRequestedExecutors(manager) assert(numExecutorsTarget(manager) === 0) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 1000))) addExecutors(manager) assert(numExecutorsTarget(manager) === 16) } @@ -762,7 +848,7 @@ class ExecutorAllocationManagerSuite // Verify whether the initial number of executors is kept with no pending tasks assert(numExecutorsTarget(manager) === 3) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 2))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 2))) clock.advance(100L) assert(maxNumExecutorsNeeded(manager) === 2) @@ -812,7 +898,7 @@ class ExecutorAllocationManagerSuite Seq.empty ) val stageInfo1 = createStageInfo(1, 5, localityPreferences1) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo1)) + post(sc.listenerBus, SparkListenerStageSubmitted(stageInfo1)) assert(localityAwareTasks(manager) === 3) assert(hostToLocalTaskCount(manager) === @@ -824,13 +910,13 @@ class ExecutorAllocationManagerSuite Seq.empty ) val stageInfo2 = createStageInfo(2, 3, localityPreferences2) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo2)) + post(sc.listenerBus, SparkListenerStageSubmitted(stageInfo2)) assert(localityAwareTasks(manager) === 5) assert(hostToLocalTaskCount(manager) === Map("host1" -> 2, "host2" -> 4, "host3" -> 4, "host4" -> 3, "host5" -> 2)) - sc.listenerBus.postToAll(SparkListenerStageCompleted(stageInfo1)) + post(sc.listenerBus, SparkListenerStageCompleted(stageInfo1)) assert(localityAwareTasks(manager) === 2) assert(hostToLocalTaskCount(manager) === Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) @@ -841,16 +927,16 @@ class ExecutorAllocationManagerSuite val manager = sc.executorAllocationManager.get assert(maxNumExecutorsNeeded(manager) === 0) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1))) assert(maxNumExecutorsNeeded(manager) === 1) val taskInfo = createTaskInfo(1, 1, "executor-1") - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, taskInfo)) assert(maxNumExecutorsNeeded(manager) === 1) // If the task is failed, we expect it to be resubmitted later. val taskEndReason = ExceptionFailure(null, null, null, null, None) - sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) + post(sc.listenerBus, SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) assert(maxNumExecutorsNeeded(manager) === 1) } @@ -862,7 +948,7 @@ class ExecutorAllocationManagerSuite // Allocation manager is reset when adding executor requests are sent without reporting back // executor added. 
- sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 10))) assert(addExecutors(manager) === 1) assert(numExecutorsTarget(manager) === 2) @@ -877,7 +963,7 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager) === Set.empty) // Allocation manager is reset when executors are added. - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 10))) addExecutors(manager) addExecutors(manager) @@ -915,12 +1001,17 @@ class ExecutorAllocationManagerSuite onExecutorAdded(manager, "third") onExecutorAdded(manager, "fourth") onExecutorAdded(manager, "fifth") - assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + onExecutorAdded(manager, "sixth") + onExecutorAdded(manager, "seventh") + onExecutorAdded(manager, "eighth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth", + "sixth", "seventh", "eighth")) removeExecutor(manager, "first") removeExecutors(manager, Seq("second", "third")) assert(executorsPendingToRemove(manager) === Set("first", "second", "third")) - assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth", + "sixth", "seventh", "eighth")) // Cluster manager lost will make all the live executors lost, so here simulate this behavior @@ -980,10 +1071,15 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { taskLocalityPreferences = taskLocalityPreferences) } - private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { - new TaskInfo(taskId, taskIndex, 0, 0, executorId, "", TaskLocality.ANY, speculative = false) + private def createTaskInfo( + taskId: Int, + taskIndex: Int, + executorId: String, + speculative: Boolean = false): TaskInfo = { + new TaskInfo(taskId, taskIndex, 0, 0, executorId, "", TaskLocality.ANY, speculative) } + /* ------------------------------------------------------- * | Helper methods for accessing private methods and fields | * ------------------------------------------------------- */ @@ -1010,6 +1106,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _onExecutorBusy = PrivateMethod[Unit]('onExecutorBusy) private val _localityAwareTasks = PrivateMethod[Int]('localityAwareTasks) private val _hostToLocalTaskCount = PrivateMethod[Map[String, Int]]('hostToLocalTaskCount) + private val _onSpeculativeTaskSubmitted = PrivateMethod[Unit]('onSpeculativeTaskSubmitted) private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = { manager invokePrivate _numExecutorsToAdd() @@ -1085,6 +1182,10 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _onExecutorBusy(id) } + private def onSpeculativeTaskSubmitted(manager: ExecutorAllocationManager, id: String) : Unit = { + manager invokePrivate _onSpeculativeTaskSubmitted(id) + } + private def localityAwareTasks(manager: ExecutorAllocationManager): Int = { manager invokePrivate _localityAwareTasks() } diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 5be0121db58a..02728180ac82 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -113,11 +113,11 @@ class FileSuite 
extends SparkFunSuite with LocalSparkContext { val normalFile = new File(normalDir, "part-00000") val normalContent = sc.sequenceFile[String, String](normalDir).collect - assert(normalContent === Array.fill(100)("abc", "abc")) + assert(normalContent === Array.fill(100)(("abc", "abc"))) val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect - assert(compressedContent === Array.fill(100)("abc", "abc")) + assert(compressedContent === Array.fill(100)(("abc", "abc"))) assert(compressedFile.length < normalFile.length) } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 99150a1430d9..8a77aea75a99 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.duration._ import scala.concurrent.Future +import scala.concurrent.duration._ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index bb24c6ce4d33..ebd826b0ba2f 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,9 +19,10 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer -import org.mockito.Matchers.{any, isA} +import org.mockito.Matchers.any import org.mockito.Mockito._ +import org.apache.spark.LocalSparkContext._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} @@ -138,21 +139,21 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) + // This is expected to fail because no outputs have been registered for the shuffle. 
intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) + val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - masterTracker.incrementEpoch() + assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) slaveTracker.updateEpoch(masterTracker.getEpoch) intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } @@ -175,7 +176,8 @@ class MapOutputTrackerSuite extends SparkFunSuite { val masterTracker = newTrackerMaster(newConf) val rpcEnv = createRpcEnv("spark") val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) - rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) + masterTracker.trackerEndpoint = + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) @@ -190,7 +192,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { verify(rpcCallContext, timeout(30000)).reply(any()) assert(0 == masterTracker.getNumCachedSerializedBroadcast) -// masterTracker.stop() // this throws an exception + masterTracker.stop() rpcEnv.shutdown() } @@ -245,8 +247,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize // needs TorrentBroadcast so need a SparkContext - val sc = new SparkContext("local", "MapOutputTrackerSuite", newConf) - try { + withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc => val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] val rpcEnv = sc.env.rpcEnv val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) @@ -271,9 +272,6 @@ class MapOutputTrackerSuite extends SparkFunSuite { assert(1 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.unregisterShuffle(20) assert(0 == masterTracker.getNumCachedSerializedBroadcast) - - } finally { - LocalSparkContext.stop(sc) } } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 34c017806fe1..dfe4c25670ce 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -253,6 +253,12 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva // Add other tests here for classes that should be able to handle empty partitions correctly } + + test("Number of elements in RDD is less than number of partitions") { + val rdd = sc.parallelize(1 to 3).map(x => (x, x)) + val partitioner = new RangePartitioner(22, rdd) + assert(partitioner.numPartitions === 3) + } } diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 6fc7cea6ee94..8eabc2b3cb95 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -22,6 +22,8 @@ import javax.net.ssl.SSLContext import org.scalatest.BeforeAndAfterAll +import org.apache.spark.util.SparkConfWithEnv + class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { test("test resolving property file as spark conf ") { @@ -133,4 +135,18 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(opts.enabledAlgorithms === Set("ABC", "DEF")) } + test("variable substitution") { + val conf = new SparkConfWithEnv(Map( + "ENV1" -> "val1", + "ENV2" -> "val2")) + + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", "${env:ENV1}") + conf.set("spark.ssl.trustStore", "${env:ENV2}") + + val opts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + assert(opts.keyStore === Some(new File("val1"))) + assert(opts.trustStore === Some(new File("val2"))) + } + } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 58b865969f51..3931d53b4ae0 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter -import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId} +import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId} import org.apache.spark.util.{MutablePair, Utils} abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { @@ -277,7 +277,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // Delete one of the local shuffle blocks. val hashFile = sc.env.blockManager.diskBlockManager.getFile(new ShuffleBlockId(0, 0, 0)) val sortFile = sc.env.blockManager.diskBlockManager.getFile(new ShuffleDataBlockId(0, 0, 0)) - assert(hashFile.exists() || sortFile.exists()) + val indexFile = sc.env.blockManager.diskBlockManager.getFile(new ShuffleIndexBlockId(0, 0, 0)) + assert(hashFile.exists() || (sortFile.exists() && indexFile.exists())) if (hashFile.exists()) { hashFile.delete() @@ -285,11 +286,36 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC if (sortFile.exists()) { sortFile.delete() } + if (indexFile.exists()) { + indexFile.delete() + } // This count should retry the execution of the previous stage and rerun shuffle. rdd.count() } + test("cannot find its local shuffle file if no execution of the stage and rerun shuffle") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 10, 1).map((_, 1)).reduceByKey(_ + _) + + // Cannot find one of the local shuffle blocks. + val hashFile = sc.env.blockManager.diskBlockManager.getFile(new ShuffleBlockId(0, 0, 0)) + val sortFile = sc.env.blockManager.diskBlockManager.getFile(new ShuffleDataBlockId(0, 0, 0)) + val indexFile = sc.env.blockManager.diskBlockManager.getFile(new ShuffleIndexBlockId(0, 0, 0)) + assert(!hashFile.exists() && !sortFile.exists() && !indexFile.exists()) + + rdd.count() + + // Can find one of the local shuffle blocks. 
+ val hashExistsFile = sc.env.blockManager.diskBlockManager + .getFile(new ShuffleBlockId(0, 0, 0)) + val sortExistsFile = sc.env.blockManager.diskBlockManager + .getFile(new ShuffleDataBlockId(0, 0, 0)) + val indexExistsFile = sc.env.blockManager.diskBlockManager + .getFile(new ShuffleIndexBlockId(0, 0, 0)) + assert(hashExistsFile.exists() || (sortExistsFile.exists() && indexExistsFile.exists())) + } + test("metrics for shuffle without aggregation") { sc = new SparkContext("local", "test", conf.clone()) val numRecords = 10000 @@ -333,6 +359,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val shuffleMapRdd = new MyRDD(sc, 1, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep) + mapTrackerMaster.registerShuffle(0, 1) // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, @@ -367,7 +394,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => - mapTrackerMaster.registerMapOutputs(0, Array(mapStatus)) + mapTrackerMaster.registerMapOutput(0, 0, mapStatus) } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 7a897c2b4698..c0126e41ff7f 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -38,6 +38,10 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { override def beforeAll() { super.beforeAll() + // Once 'spark.local.dir' is set, it is cached. Unless this is manually cleared + // before/after a test, it could return the same directory even if this property + // is configured. 
+ Utils.clearLocalRootDirs() conf.set("spark.shuffle.manager", "sort") } @@ -50,6 +54,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { override def afterEach(): Unit = { try { Utils.deleteRecursively(tempDir) + Utils.clearLocalRootDirs() } finally { super.afterEach() } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 7e26139a2bea..0ed5f26863da 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import scala.concurrent.duration._ -import scala.concurrent.Await import com.google.common.io.Files import org.apache.hadoop.conf.Configuration @@ -31,11 +30,11 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} -import org.scalatest.concurrent.Eventually import org.scalatest.Matchers._ +import org.scalatest.concurrent.Eventually import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually { @@ -301,13 +300,13 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) sc.addJar(tmpJar.getAbsolutePath) - // Invaid jar path will only print the error log, will not add to file server. + // Invalid jar path will only print the error log, will not add to file server. sc.addJar("dummy.jar") sc.addJar("") sc.addJar(tmpDir.getAbsolutePath) - sc.listJars().size should be (1) - sc.listJars().head should include (tmpJar.getName) + assert(sc.listJars().size == 1) + assert(sc.listJars().head.contains(tmpJar.getName)) } test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { @@ -315,7 +314,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)}) sc.cancelJobGroup("nonExistGroupId") - Await.ready(future, Duration(2, TimeUnit.SECONDS)) + ThreadUtils.awaitReady(future, Duration(2, TimeUnit.SECONDS)) // In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause // SparkContext to shutdown, so the following assertion will fail. 
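The SparkContextSuite hunk above swaps `Await.ready` for `ThreadUtils.awaitReady`, which the patch also imports via `org.apache.spark.util.ThreadUtils`. A hedged usage sketch follows; the async job and timeout simply mirror the test being changed and are illustrative rather than normative:

```scala
import java.util.concurrent.TimeUnit
import scala.concurrent.duration.Duration
import org.apache.spark.util.ThreadUtils

// Block until the async action completes (or the timeout elapses) without calling
// scala.concurrent.Await directly from test code.
val future = sc.parallelize(Seq(0)).foreachAsync(_ => Thread.sleep(100L))
ThreadUtils.awaitReady(future, Duration(2, TimeUnit.SECONDS))
```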
@@ -601,6 +600,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu val fs = new DebugFilesystem() fs.initialize(new URI("file:///"), new Configuration()) val file = File.createTempFile("SPARK19446", "temp") + file.deleteOnExit() Files.write(Array.ofDim[Byte](1000), file) val path = new Path("file:///" + file.getCanonicalPath) val stream = fs.open(path) diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index 09e21646ee74..bc3f58cf2a35 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.scalatest.concurrent.Timeouts._ +import org.scalatest.concurrent.TimeLimits._ import org.scalatest.time.{Millis, Span} class UnpersistSuite extends SparkFunSuite with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala index 9c13c15281a4..55a541d60ea3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -39,7 +39,7 @@ private[deploy] object DeployTestUtils { } def createDriverCommand(): Command = new Command( - "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), + "org.apache.spark.FakeClass", Seq("WORKER_URL", "USER_JAR", "mainClass"), Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") ) @@ -47,7 +47,7 @@ private[deploy] object DeployTestUtils { new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", - createDriverDesc(), new Date()) + createDriverDesc(), JsonConstants.submitDate) def createWorkerInfo(): WorkerInfo = { val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, "http://publicAddress:80") diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index f50cb38311db..42b8cde65039 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -243,16 +243,22 @@ private[deploy] object IvyTestUtils { withManifest: Option[Manifest] = None): File = { val jarFile = new File(dir, artifactName(artifact, useIvyLayout)) val jarFileStream = new FileOutputStream(jarFile) - val manifest = withManifest.getOrElse { - val mani = new Manifest() + val manifest: Manifest = withManifest.getOrElse { if (withR) { + val mani = new Manifest() val attr = mani.getMainAttributes attr.put(Name.MANIFEST_VERSION, "1.0") attr.put(new Name("Spark-HasRPackage"), "true") + mani + } else { + null } - mani } - val jarStream = new JarOutputStream(jarFileStream, manifest) + val jarStream = if (manifest != null) { + new JarOutputStream(jarFileStream, manifest) + } else { + new JarOutputStream(jarFileStream) + } for (file <- files) { val jarEntry = new JarEntry(file._1) diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 7093dad05c5f..1903130cb694 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -104,8 +104,8 @@ object JsonConstants { val 
submitDate = new Date(123456789) val appInfoJsonStr = """ - |{"starttime":3,"id":"id","name":"name", - |"cores":4,"user":"%s", + |{"id":"id","starttime":3,"name":"name", + |"cores":0,"user":"%s", |"memoryperslave":1234,"submitdate":"%s", |"state":"WAITING","duration":%d} """.format(System.getProperty("user.name", ""), @@ -134,19 +134,24 @@ object JsonConstants { val driverInfoJsonStr = """ - |{"id":"driver-3","starttime":"3","state":"SUBMITTED","cores":3,"memory":100} - """.stripMargin + |{"id":"driver-3","starttime":"3", + |"state":"SUBMITTED","cores":3,"memory":100, + |"submitdate":"%s","worker":"None", + |"mainclass":"mainClass"} + """.format(submitDate.toString).stripMargin val masterStateJsonStr = """ |{"url":"spark://host:8080", |"workers":[%s,%s], + |"aliveworkers":2, |"cores":8,"coresused":0,"memory":2468,"memoryused":0, |"activeapps":[%s],"completedapps":[], |"activedrivers":[%s], + |"completeddrivers":[%s], |"status":"ALIVE"} """.format(workerInfoJsonStr, workerInfoJsonStr, - appInfoJsonStr, driverInfoJsonStr).stripMargin + appInfoJsonStr, driverInfoJsonStr, driverInfoJsonStr).stripMargin val workerStateJsonStr = """ diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 005587051b6a..32dd3ecc2f02 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -133,6 +133,17 @@ class RPackageUtilsSuite } } + test("jars without manifest return false") { + IvyTestUtils.withRepository(main, None, None) { repo => + val jar = IvyTestUtils.packJar(new File(new URI(repo)), dep1, Nil, + useIvyLayout = false, withR = false, None) + Utils.tryWithResource(new JarFile(jar)) { jarFile => + assert(jarFile.getManifest == null, "jar file should have null manifest") + assert(!RPackageUtils.checkManifestForR(jarFile), "null manifest should return false") + } + } + } + test("SparkR zipping works properly") { val tempDir = Files.createTempDir() Utils.tryWithSafeFinally { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a43839a8815f..ad801bf8519a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -18,28 +18,32 @@ package org.apache.spark.deploy import java.io._ +import java.net.URI import java.nio.charset.StandardCharsets +import java.nio.file.Files +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.io.Source import com.google.common.io.ByteStreams -import org.apache.hadoop.fs.Path +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, Path} import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.TimeLimits import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate -import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging -import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.internal.config._ import 
org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} - trait TestPrematureExit { suite: SparkFunSuite => @@ -93,7 +97,7 @@ class SparkSubmitSuite with Matchers with BeforeAndAfterEach with ResetSystemProperties - with Timeouts + with TimeLimits with TestPrematureExit { override def beforeEach() { @@ -474,6 +478,26 @@ class SparkSubmitSuite } } + test("includes jars passed through spark.jars.packages and spark.jars.repositories") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val main = MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") + // Test using "spark.jars.packages" and "spark.jars.repositories" configurations. + IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => + val args = Seq( + "--class", JarCreationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.jars.packages=my.great.lib:mylib:0.1,my.great.dep:mylib:0.1", + "--conf", s"spark.jars.repositories=$repo", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + unusedJar.toString, + "my.great.lib.MyLib", "my.great.dep.MyLib") + runSparkSubmit(args) + } + } + // TODO(SPARK-9603): Building a package is flaky on Jenkins Maven builds. // See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log ignore("correctly builds R packages included in a jar with --packages") { @@ -482,8 +506,8 @@ class SparkSubmitSuite assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val rScriptDir = - Seq(sparkHome, "R", "pkg", "inst", "tests", "packageInAJarTest.R").mkString(File.separator) + val rScriptDir = Seq( + sparkHome, "R", "pkg", "tests", "fulltests", "packageInAJarTest.R").mkString(File.separator) assert(new File(rScriptDir).exists) IvyTestUtils.withRepository(main, None, None, withR = true) { repo => val args = Seq( @@ -504,7 +528,7 @@ class SparkSubmitSuite // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val rScriptDir = - Seq(sparkHome, "R", "pkg", "inst", "tests", "testthat", "jarTest.R").mkString(File.separator) + Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator) assert(new File(rScriptDir).exists) // compile a small jar containing a class that will be called from R code. 
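For context on the new test above: dependency coordinates and resolvers can be supplied to `spark-submit` as plain configuration (`spark.jars.packages`, `spark.jars.repositories`) instead of the `--packages`/`--repositories` flags, which is exactly what the test drives through `--conf`. A minimal sketch of such an argument list, with placeholder class name, coordinates, repository, and jar path; in the suite the equivalent `Seq` is handed to `runSparkSubmit`:

```scala
// spark-submit style arguments that pull in Maven dependencies purely via configuration.
// Everything below is a placeholder except the two configuration keys themselves.
val submitArgs: Seq[String] = Seq(
  "--class", "com.example.Main",
  "--name", "packagesViaConf",
  "--master", "local-cluster[2,1,1024]",
  "--conf", "spark.jars.packages=com.example:some-lib_2.11:0.1",
  "--conf", "spark.jars.repositories=file:///tmp/local-m2-repo",
  "--conf", "spark.ui.enabled=false",
  "/tmp/app.jar")
```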
@@ -535,7 +559,7 @@ class SparkSubmitSuite test("resolves command line argument paths correctly") { val jars = "/jar1,/jar2" // --jars - val files = "hdfs:/file1,file2" // --files + val files = "local:/file1,file2" // --files val archives = "file:/archive1,archive2" // --archives val pyFiles = "py-file1,py-file2" // --py-files @@ -587,7 +611,7 @@ class SparkSubmitSuite test("resolves config paths correctly") { val jars = "/jar1,/jar2" // spark.jars - val files = "hdfs:/file1,file2" // spark.files / spark.yarn.dist.files + val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files val archives = "file:/archive1,archive2" // spark.yarn.dist.archives val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles @@ -703,8 +727,241 @@ class SparkSubmitSuite Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar")) } + + test("support glob path") { + val tmpJarDir = Utils.createTempDir() + val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir) + val jar2 = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpJarDir) + + val tmpFileDir = Utils.createTempDir() + val file1 = File.createTempFile("tmpFile1", "", tmpFileDir) + val file2 = File.createTempFile("tmpFile2", "", tmpFileDir) + + val tmpPyFileDir = Utils.createTempDir() + val pyFile1 = File.createTempFile("tmpPy1", ".py", tmpPyFileDir) + val pyFile2 = File.createTempFile("tmpPy2", ".egg", tmpPyFileDir) + + val tmpArchiveDir = Utils.createTempDir() + val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir) + val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir) + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--jars", s"${tmpJarDir.getAbsolutePath}/*.jar", + "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*", + "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*", + "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip", + jar2.toString) + + val appArgs = new SparkSubmitArguments(args) + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 + sysProps("spark.yarn.dist.jars").split(",").toSet should be + (Set(jar1.toURI.toString, jar2.toURI.toString)) + sysProps("spark.yarn.dist.files").split(",").toSet should be + (Set(file1.toURI.toString, file2.toURI.toString)) + sysProps("spark.yarn.dist.pyFiles").split(",").toSet should be + (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath)) + sysProps("spark.yarn.dist.archives").split(",").toSet should be + (Set(archive1.toURI.toString, archive2.toURI.toString)) + } + // scalastyle:on println + private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = { + if (sourcePath == outputPath) { + return + } + + val sourceUri = new URI(sourcePath) + val outputUri = new URI(outputPath) + assert(outputUri.getScheme === "file") + + // The path and filename are preserved. 
+ assert(outputUri.getPath.endsWith(new Path(sourceUri).getName)) + assert(FileUtils.readFileToString(new File(outputUri.getPath)) === + FileUtils.readFileToString(new File(sourceUri.getPath))) + } + + private def deleteTempOutputFile(outputPath: String): Unit = { + val outputFile = new File(new URI(outputPath).getPath) + if (outputFile.exists) { + outputFile.delete() + } + } + + test("downloadFile - invalid url") { + val sparkConf = new SparkConf(false) + intercept[IOException] { + DependencyUtils.downloadFile( + "abc:/my/file", Utils.createTempDir(), sparkConf, new Configuration(), + new SecurityManager(sparkConf)) + } + } + + test("downloadFile - file doesn't exist") { + val sparkConf = new SparkConf(false) + val hadoopConf = new Configuration() + val tmpDir = Utils.createTempDir() + updateConfWithFakeS3Fs(hadoopConf) + intercept[FileNotFoundException] { + DependencyUtils.downloadFile("s3a:/no/such/file", tmpDir, sparkConf, hadoopConf, + new SecurityManager(sparkConf)) + } + } + + test("downloadFile does not download local file") { + val sparkConf = new SparkConf(false) + val secMgr = new SecurityManager(sparkConf) + // empty path is considered as local file. + val tmpDir = Files.createTempDirectory("tmp").toFile + assert(DependencyUtils.downloadFile("", tmpDir, sparkConf, new Configuration(), secMgr) === "") + assert(DependencyUtils.downloadFile("/local/file", tmpDir, sparkConf, new Configuration(), + secMgr) === "/local/file") + } + + test("download one file to local") { + val sparkConf = new SparkConf(false) + val jarFile = File.createTempFile("test", ".jar") + jarFile.deleteOnExit() + val content = "hello, world" + FileUtils.write(jarFile, content) + val hadoopConf = new Configuration() + val tmpDir = Files.createTempDirectory("tmp").toFile + updateConfWithFakeS3Fs(hadoopConf) + val sourcePath = s"s3a://${jarFile.toURI.getPath}" + val outputPath = DependencyUtils.downloadFile(sourcePath, tmpDir, sparkConf, hadoopConf, + new SecurityManager(sparkConf)) + checkDownloadedFile(sourcePath, outputPath) + deleteTempOutputFile(outputPath) + } + + test("download list of files to local") { + val sparkConf = new SparkConf(false) + val jarFile = File.createTempFile("test", ".jar") + jarFile.deleteOnExit() + val content = "hello, world" + FileUtils.write(jarFile, content) + val hadoopConf = new Configuration() + val tmpDir = Files.createTempDirectory("tmp").toFile + updateConfWithFakeS3Fs(hadoopConf) + val sourcePaths = Seq("/local/file", s"s3a://${jarFile.toURI.getPath}") + val outputPaths = DependencyUtils + .downloadFileList(sourcePaths.mkString(","), tmpDir, sparkConf, hadoopConf, + new SecurityManager(sparkConf)) + .split(",") + + assert(outputPaths.length === sourcePaths.length) + sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) => + checkDownloadedFile(sourcePath, outputPath) + deleteTempOutputFile(outputPath) + } + } + + test("Avoid re-upload remote resources in yarn client mode") { + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + + val tmpDir = Utils.createTempDir() + val file = File.createTempFile("tmpFile", "", tmpDir) + val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) + val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) + val tmpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) + val tmpJarPath = s"s3a://${new File(tmpJar.toURI).getAbsolutePath}" + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + 
"--deploy-mode", "client", + "--jars", tmpJarPath, + "--files", s"s3a://${file.getAbsolutePath}", + "--py-files", s"s3a://${pyFile.getAbsolutePath}", + s"s3a://$mainResource" + ) + + val appArgs = new SparkSubmitArguments(args) + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + + // All the resources should still be remote paths, so that YARN client will not upload again. + sysProps("spark.yarn.dist.jars") should be (tmpJarPath) + sysProps("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") + sysProps("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") + + // Local repl jars should be a local path. + sysProps("spark.repl.local.jars") should (startWith("file:")) + + // local py files should not be a URI format. + sysProps("spark.submit.pyFiles") should (startWith("/")) + } + + test("download remote resource if it is not supported by yarn service") { + testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = false) + } + + test("avoid downloading remote resource if it is supported by yarn service") { + testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = true) + } + + test("force download from blacklisted schemes") { + testRemoteResources(isHttpSchemeBlacklisted = true, supportMockHttpFs = true) + } + + private def testRemoteResources(isHttpSchemeBlacklisted: Boolean, + supportMockHttpFs: Boolean): Unit = { + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + if (supportMockHttpFs) { + hadoopConf.set("fs.http.impl", classOf[TestFileSystem].getCanonicalName) + hadoopConf.set("fs.http.impl.disable.cache", "true") + } + + val tmpDir = Utils.createTempDir() + val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) + val tmpS3Jar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) + val tmpS3JarPath = s"s3a://${new File(tmpS3Jar.toURI).getAbsolutePath}" + val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) + val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--jars", s"$tmpS3JarPath,$tmpHttpJarPath", + s"s3a://$mainResource" + ) ++ ( + if (isHttpSchemeBlacklisted) { + Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http,https") + } else { + Nil + } + ) + + val appArgs = new SparkSubmitArguments(args) + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + + val jars = sysProps("spark.yarn.dist.jars").split(",").toSet + + // The URI of remote S3 resource should still be remote. + assert(jars.contains(tmpS3JarPath)) + + if (supportMockHttpFs) { + // If Http FS is supported by yarn service, the URI of remote http resource should + // still be remote. + assert(jars.contains(tmpHttpJarPath)) + } else { + // If Http FS is not supported by yarn service, or http scheme is configured to be force + // downloading, the URI of remote http resource should be changed to a local one. + val jarName = new File(tmpHttpJar.toURI).getName + val localHttpJar = jars.filter(_.contains(jarName)) + localHttpJar.size should be(1) + localHttpJar.head should startWith("file:") + } + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -744,6 +1001,11 @@ class SparkSubmitSuite Utils.deleteRecursively(tmpDir) } } + + private def updateConfWithFakeS3Fs(conf: Configuration): Unit = { + conf.set("fs.s3a.impl", classOf[TestFileSystem].getCanonicalName) + conf.set("fs.s3a.impl.disable.cache", "true") + } } object JarCreationTest extends Logging { @@ -807,3 +1069,33 @@ object UserClasspathFirstTest { } } } + +class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem { + private def local(path: Path): Path = { + // Ignore the scheme for testing. + new Path(path.toUri.getPath) + } + + private def toRemote(status: FileStatus): FileStatus = { + val path = s"s3a://${status.getPath.toUri.getPath}" + status.setPath(new Path(path)) + status + } + + override def isFile(path: Path): Boolean = super.isFile(local(path)) + + override def globStatus(pathPattern: Path): Array[FileStatus] = { + val newPath = new Path(pathPattern.toUri.getPath) + super.globStatus(newPath).map(toRemote) + } + + override def listStatus(path: Path): Array[FileStatus] = { + super.listStatus(local(path)).map(toRemote) + } + + override def copyToLocalFile(src: Path, dst: Path): Unit = { + super.copyToLocalFile(local(src), dst) + } + + override def open(path: Path): FSDataInputStream = super.open(local(path)) +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 266c9d33b5a9..eb8c203ae775 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -83,18 +83,19 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { val resolver = settings.getDefaultResolver.asInstanceOf[ChainResolver] assert(resolver.getResolvers.size() === 4) val expected = repos.split(",").map(r => s"$r/") - resolver.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) => - if (1 < i && i < 3) { - assert(resolver.getName === s"repo-$i") - assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i - 1)) - } + resolver.getResolvers.toArray.map(_.asInstanceOf[AbstractResolver]).zipWithIndex.foreach { + case (r, i) => + if (1 < i && i < 3) { + assert(r.getName === s"repo-$i") + assert(r.asInstanceOf[IBiblioResolver].getRoot === expected(i - 1)) + } } } test("add dependencies works correctly") { val md = SparkSubmitUtils.getModuleDescriptor - val artifacts = SparkSubmitUtils.extractMavenCoordinates("com.databricks:spark-csv_2.10:0.1," + - "com.databricks:spark-avro_2.10:0.1") + val artifacts = SparkSubmitUtils.extractMavenCoordinates("com.databricks:spark-csv_2.11:0.1," + + "com.databricks:spark-avro_2.11:0.1") SparkSubmitUtils.addDependenciesToIvy(md, artifacts, "default") assert(md.getDependencies.length === 2) @@ -187,19 +188,16 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { } test("neglects Spark and Spark's dependencies") { - val components = Seq("catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", - "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") - - val coordinates = - components.map(comp => s"org.apache.spark:spark-${comp}2.10:1.2.0").mkString(",") + - ",org.apache.spark:spark-core_fake:1.2.0" + val coordinates = SparkSubmitUtils.IVY_DEFAULT_EXCLUDES + .map(comp => 
s"org.apache.spark:spark-${comp}2.11:2.1.1") + .mkString(",") + ",org.apache.spark:spark-core_fake:1.2.0" val path = SparkSubmitUtils.resolveMavenCoordinates( coordinates, SparkSubmitUtils.buildIvySettings(None, None), isTest = true) assert(path === "", "should return empty path") - val main = MavenCoordinate("org.apache.spark", "spark-streaming-kafka-assembly_2.10", "1.2.0") + val main = MavenCoordinate("org.apache.spark", "spark-streaming-kafka-assembly_2.11", "1.2.0") IvyTestUtils.withRepository(main, None, None) { repo => val files = SparkSubmitUtils.resolveMavenCoordinates( coordinates + "," + main.toString, diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 936639b84578..a1707e6540b3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -214,6 +214,8 @@ class AppClientSuite id: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit = { execRemovedList.add(id) } + + def workerRemoved(workerId: String, host: String, message: String): Unit = {} } /** Create AppClient and supporting objects */ diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala index 7998e3702c12..6e50e8454904 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala @@ -33,7 +33,7 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.Matchers -import org.scalatest.mock.MockitoSugar +import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging @@ -78,7 +78,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar logDebug(s"getAppUI($appId, $attemptId)") getAppUICount += 1 instances.get(CacheKey(appId, attemptId)).map( e => - LoadedAppUI(e.ui, updateProbe(appId, attemptId, e.probeTime))) + LoadedAppUI(e.ui, () => updateProbe(appId, attemptId, e.probeTime))) } override def attachSparkUI( @@ -122,7 +122,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar completed: Boolean, timestamp: Long): Unit = { instances += (CacheKey(appId, attemptId) -> - new CacheEntry(ui, completed, updateProbe(appId, attemptId, timestamp), timestamp)) + new CacheEntry(ui, completed, () => updateProbe(appId, attemptId, timestamp), timestamp)) } /** @@ -177,7 +177,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar ended: Long): SparkUI = { val info = new ApplicationInfo(name, name, Some(1), Some(1), Some(1), Some(64), Seq(new AttemptInfo(attemptId, new Date(started), new Date(ended), - new Date(ended), ended - started, "user", completed))) + new Date(ended), ended - started, "user", completed, org.apache.spark.SPARK_VERSION))) val ui = mock[SparkUI] when(ui.getApplicationInfoList).thenReturn(List(info).iterator) when(ui.getAppName).thenReturn(name) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 456158d41b93..2141934c9264 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.history import java.io._ -import java.net.URI import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} @@ -27,7 +26,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any @@ -37,6 +36,7 @@ import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.io._ import org.apache.spark.scheduler._ @@ -63,13 +63,19 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc codec: Option[String] = None): File = { val ip = if (inProgress) EventLoggingListener.IN_PROGRESS else "" val logUri = EventLoggingListener.getLogPath(testDir.toURI, appId, appAttemptId) - val logPath = new URI(logUri).getPath + ip + val logPath = new Path(logUri).toUri.getPath + ip new File(logPath) } - test("Parse application logs") { + Seq(true, false).foreach { inMemory => + test(s"Parse application logs (inMemory = $inMemory)") { + testAppLogParsing(inMemory) + } + } + + private def testAppLogParsing(inMemory: Boolean) { val clock = new ManualClock(12345678) - val provider = new FsHistoryProvider(createTestConf(), clock) + val provider = new FsHistoryProvider(createTestConf(inMemory = inMemory), clock) // Write a new-style application log. val newAppComplete = newLogFile("new1", None, inProgress = false) @@ -109,7 +115,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc user: String, completed: Boolean): ApplicationHistoryInfo = { ApplicationHistoryInfo(id, name, - List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed))) + List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed, ""))) } // For completed files, lastUpdated would be lastModified time. 
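The `Seq(true, false).foreach { inMemory => test(...) }` idiom introduced above registers one named test case per parameter value at suite-construction time, so each variant runs and reports independently. A self-contained sketch of the same pattern using plain ScalaTest (suite and helper names are illustrative, not from this patch):

```scala
import org.scalatest.FunSuite

class AppLogParsingSuite extends FunSuite {

  // Each loop iteration calls test(...), which registers a separately named test case.
  Seq(true, false).foreach { inMemory =>
    test(s"parse application logs (inMemory = $inMemory)") {
      testAppLogParsing(inMemory)
    }
  }

  private def testAppLogParsing(inMemory: Boolean): Unit = {
    // Placeholder body: the real suite builds an FsHistoryProvider with or without
    // an on-disk listing store depending on the flag.
    assert(inMemory || !inMemory)
  }
}
```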
@@ -173,20 +179,18 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) updateAndCheck(provider) { list => list.size should be (1) - list.head.attempts.head.asInstanceOf[FsApplicationAttemptInfo].logPath should - endWith(EventLoggingListener.IN_PROGRESS) + provider.getAttempt("app1", None).logPath should endWith(EventLoggingListener.IN_PROGRESS) } logFile1.renameTo(newLogFile("app1", None, inProgress = false)) updateAndCheck(provider) { list => list.size should be (1) - list.head.attempts.head.asInstanceOf[FsApplicationAttemptInfo].logPath should not - endWith(EventLoggingListener.IN_PROGRESS) + provider.getAttempt("app1", None).logPath should not endWith(EventLoggingListener.IN_PROGRESS) } } test("Parse logs that application is not started") { - val provider = new FsHistoryProvider((createTestConf())) + val provider = new FsHistoryProvider(createTestConf()) val logFile1 = newLogFile("app1", None, inProgress = true) writeFile(logFile1, true, None, @@ -343,17 +347,23 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc provider.checkForLogs() // This should not trigger any cleanup - updateAndCheck(provider)(list => list.size should be(2)) + updateAndCheck(provider) { list => + list.size should be(2) + } // Should trigger cleanup for first file but not second one clock.setTime(firstFileModifiedTime + maxAge + 1) - updateAndCheck(provider)(list => list.size should be(1)) + updateAndCheck(provider) { list => + list.size should be(1) + } assert(!log1.exists()) assert(log2.exists()) // Should cleanup the second file as well. clock.setTime(secondFileModifiedTime + maxAge + 1) - updateAndCheck(provider)(list => list.size should be(0)) + updateAndCheck(provider) { list => + list.size should be(0) + } assert(!log1.exists()) assert(!log2.exists()) } @@ -581,7 +591,34 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc securityManager.checkUIViewPermissions("user4") should be (false) securityManager.checkUIViewPermissions("user5") should be (false) } - } + } + + test("mismatched version discards old listing") { + val conf = createTestConf() + val oldProvider = new FsHistoryProvider(conf) + + val logFile1 = newLogFile("app1", None, inProgress = false) + writeFile(logFile1, true, None, + SparkListenerLogStart("2.3"), + SparkListenerApplicationStart("test", Some("test"), 1L, "test", None), + SparkListenerApplicationEnd(5L) + ) + + updateAndCheck(oldProvider) { list => + list.size should be (1) + } + assert(oldProvider.listing.count(classOf[ApplicationInfoWrapper]) === 1) + + // Manually overwrite the version in the listing db; this should cause the new provider to + // discard all data because the versions don't match. 
+ val meta = new KVStoreMetadata(FsHistoryProvider.CURRENT_LISTING_VERSION + 1, + conf.get(LOCAL_STORE_DIR).get) + oldProvider.listing.setMetadata(meta) + oldProvider.stop() + + val mistatchedVersionProvider = new FsHistoryProvider(conf) + assert(mistatchedVersionProvider.listing.count(classOf[ApplicationInfoWrapper]) === 0) + } /** * Asks the provider to check for logs and calls a function to perform checks on the updated @@ -606,7 +643,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc if (isNewFormat) { val newFormatStream = new FileOutputStream(file) Utils.tryWithSafeFinally { - EventLoggingListener.initEventLog(newFormatStream) + EventLoggingListener.initEventLog(newFormatStream, false, null) } { newFormatStream.close() } @@ -624,8 +661,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new FileOutputStream(file).close() } - private def createTestConf(): SparkConf = { - new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + private def createTestConf(inMemory: Boolean = false): SparkConf = { + val conf = new SparkConf() + .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + + if (!inMemory) { + conf.set(LOCAL_STORE_DIR, Utils.createTempDir().getAbsolutePath()) + } + + conf } private class SafeModeTestProvider(conf: SparkConf, clock: Clock) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 95acb9a54440..c11543a4b3ba 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -39,10 +39,11 @@ import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually -import org.scalatest.mock.MockitoSugar +import org.scalatest.mockito.MockitoSugar import org.scalatest.selenium.WebBrowser import org.apache.spark._ +import org.apache.spark.deploy.history.config._ import org.apache.spark.ui.SparkUI import org.apache.spark.ui.jobs.UIData.JobUIData import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -64,6 +65,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers private val logDir = getTestResourcePath("spark-events") private val expRoot = getTestResourceFile("HistoryServerExpectations") + private val storeDir = Utils.createTempDir(namePrefix = "history") private var provider: FsHistoryProvider = null private var server: HistoryServer = null @@ -74,6 +76,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") + .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) conf.setAll(extraConf) provider = new FsHistoryProvider(conf) provider.checkForLogs() @@ -87,14 +90,13 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers def stop(): Unit = { server.stop() + server = null } before { - init() - } - - after{ - stop() + if (server == null) { + init() + } } val cases = Seq( @@ -296,6 +298,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") + .set(LOCAL_STORE_DIR, 
storeDir.getAbsolutePath()) provider = new FsHistoryProvider(conf) provider.checkForLogs() @@ -372,6 +375,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } test("incomplete apps get refreshed") { + server.stop() implicit val webDriver: WebDriver = new HtmlUnitDriver implicit val formats = org.json4s.DefaultFormats @@ -388,6 +392,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set("spark.history.fs.update.interval", "1s") .set("spark.eventLog.enabled", "true") .set("spark.history.cache.window", "250ms") + .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) .remove("spark.testing") val provider = new FsHistoryProvider(myConf) val securityManager = HistoryServer.createSecurityManager(myConf) @@ -413,8 +418,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } } - // stop the server with the old config, and start the new one - server.stop() server = new HistoryServer(myConf, provider, securityManager, 18080) server.initialize() server.bind() diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 2127da48ece4..84b3a29b58bf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -19,14 +19,19 @@ package org.apache.spark.deploy.master import java.util.Date import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.{HashMap, HashSet} import scala.concurrent.duration._ import scala.io.Source import scala.language.postfixOps +import scala.reflect.ClassTag import org.json4s._ import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} @@ -34,7 +39,52 @@ import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy._ import org.apache.spark.deploy.DeployMessages._ -import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.serializer + +object MockWorker { + val counter = new AtomicInteger(10000) +} + +class MockWorker(master: RpcEndpointRef, conf: SparkConf = new SparkConf) extends RpcEndpoint { + val seq = MockWorker.counter.incrementAndGet() + val id = seq.toString + override val rpcEnv: RpcEnv = RpcEnv.create("worker", "localhost", seq, + conf, new SecurityManager(conf)) + var apps = new mutable.HashMap[String, String]() + val driverIdToAppId = new mutable.HashMap[String, String]() + def newDriver(driverId: String): RpcEndpointRef = { + val name = s"driver_${drivers.size}" + rpcEnv.setupEndpoint(name, new RpcEndpoint { + override val rpcEnv: RpcEnv = MockWorker.this.rpcEnv + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId, _) => + apps(appId) = appId + driverIdToAppId(driverId) = appId + } + }) + } + + val appDesc = DeployTestUtils.createAppDesc() + val drivers = mutable.HashSet[String]() + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, _, _) => + 
masterRef.send(WorkerLatestState(id, Nil, drivers.toSeq)) + case LaunchDriver(driverId, desc) => + drivers += driverId + master.send(RegisterApplication(appDesc, newDriver(driverId))) + case KillDriver(driverId) => + master.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + drivers -= driverId + driverIdToAppId.get(driverId) match { + case Some(appId) => + apps.remove(appId) + master.send(UnregisterApplication(appId)) + case None => + } + driverIdToAppId.remove(driverId) + } +} class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter { @@ -134,6 +184,81 @@ class MasterSuite extends SparkFunSuite CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts } + test("master correctly recover the application") { + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.deploy.recoveryMode", "CUSTOM") + conf.set("spark.deploy.recoveryMode.factory", + classOf[FakeRecoveryModeFactory].getCanonicalName) + conf.set("spark.master.rest.enabled", "false") + + val fakeAppInfo = makeAppInfo(1024) + val fakeWorkerInfo = makeWorkerInfo(8192, 16) + val fakeDriverInfo = new DriverInfo( + startTime = 0, + id = "test_driver", + desc = new DriverDescription( + jarUrl = "", + mem = 1024, + cores = 1, + supervise = false, + command = new Command("", Nil, Map.empty, Nil, Nil, Nil)), + submitDate = new Date()) + + // Build the fake recovery data + FakeRecoveryModeFactory.persistentData.put(s"app_${fakeAppInfo.id}", fakeAppInfo) + FakeRecoveryModeFactory.persistentData.put(s"driver_${fakeDriverInfo.id}", fakeDriverInfo) + FakeRecoveryModeFactory.persistentData.put(s"worker_${fakeWorkerInfo.id}", fakeWorkerInfo) + + var master: Master = null + try { + master = makeMaster(conf) + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + // Wait until Master recover from checkpoint data. + eventually(timeout(5 seconds), interval(100 milliseconds)) { + master.workers.size should be(1) + } + + master.idToApp.keySet should be(Set(fakeAppInfo.id)) + getDrivers(master) should be(Set(fakeDriverInfo)) + master.workers should be(Set(fakeWorkerInfo)) + + // Notify Master about the executor and driver info to make it correctly recovered. + val fakeExecutors = List( + new ExecutorDescription(fakeAppInfo.id, 0, 8, ExecutorState.RUNNING), + new ExecutorDescription(fakeAppInfo.id, 0, 7, ExecutorState.RUNNING)) + + fakeAppInfo.state should be(ApplicationState.UNKNOWN) + fakeWorkerInfo.coresFree should be(16) + fakeWorkerInfo.coresUsed should be(0) + + master.self.send(MasterChangeAcknowledged(fakeAppInfo.id)) + eventually(timeout(1 second), interval(10 milliseconds)) { + // Application state should be WAITING when "MasterChangeAcknowledged" event executed. 
+ fakeAppInfo.state should be(ApplicationState.WAITING) + } + + master.self.send( + WorkerSchedulerStateResponse(fakeWorkerInfo.id, fakeExecutors, Seq(fakeDriverInfo.id))) + + eventually(timeout(5 seconds), interval(100 milliseconds)) { + getState(master) should be(RecoveryState.ALIVE) + } + + // If driver's resource is also counted, free cores should 0 + fakeWorkerInfo.coresFree should be(0) + fakeWorkerInfo.coresUsed should be(16) + // State of application should be RUNNING + fakeAppInfo.state should be(ApplicationState.RUNNING) + } finally { + if (master != null) { + master.rpcEnv.shutdown() + master.rpcEnv.awaitTermination() + master = null + FakeRecoveryModeFactory.persistentData.clear() + } + } + } + test("master/worker web ui available") { implicit val formats = org.json4s.DefaultFormats val conf = new SparkConf() @@ -394,6 +519,9 @@ class MasterSuite extends SparkFunSuite // ========================================== private val _scheduleExecutorsOnWorkers = PrivateMethod[Array[Int]]('scheduleExecutorsOnWorkers) + private val _drivers = PrivateMethod[HashSet[DriverInfo]]('drivers) + private val _state = PrivateMethod[RecoveryState.Value]('state) + private val workerInfo = makeWorkerInfo(4096, 10) private val workerInfos = Array(workerInfo, workerInfo, workerInfo) @@ -412,12 +540,18 @@ class MasterSuite extends SparkFunSuite val desc = new ApplicationDescription( "test", maxCores, memoryPerExecutorMb, null, "", None, None, coresPerExecutor) val appId = System.currentTimeMillis.toString - new ApplicationInfo(0, appId, desc, new Date, null, Int.MaxValue) + val endpointRef = mock(classOf[RpcEndpointRef]) + val mockAddress = mock(classOf[RpcAddress]) + when(endpointRef.address).thenReturn(mockAddress) + new ApplicationInfo(0, appId, desc, new Date, endpointRef, Int.MaxValue) } private def makeWorkerInfo(memoryMb: Int, cores: Int): WorkerInfo = { val workerId = System.currentTimeMillis.toString - new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, "http://localhost:80") + val endpointRef = mock(classOf[RpcEndpointRef]) + val mockAddress = mock(classOf[RpcAddress]) + when(endpointRef.address).thenReturn(mockAddress) + new WorkerInfo(workerId, "host", 100, cores, memoryMb, endpointRef, "http://localhost:80") } private def scheduleExecutorsOnWorkers( @@ -442,13 +576,20 @@ class MasterSuite extends SparkFunSuite override val rpcEnv: RpcEnv = master.rpcEnv override def receive: PartialFunction[Any, Unit] = { - case KillExecutor(_, appId, execId) => killedExecutors.add(appId, execId) + case KillExecutor(_, appId, execId) => killedExecutors.add((appId, execId)) case KillDriver(driverId) => killedDrivers.add(driverId) } }) - master.self.send( - RegisterWorker("1", "localhost", 9999, fakeWorker, 10, 1024, "http://localhost:8080")) + master.self.send(RegisterWorker( + "1", + "localhost", + 9999, + fakeWorker, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost", 9999))) val executors = (0 until 3).map { i => new ExecutorDescription(appId = i.toString, execId = i, 2, ExecutorState.RUNNING) } @@ -459,4 +600,137 @@ class MasterSuite extends SparkFunSuite assert(killedDrivers.asScala.toList.sorted === List("0", "1", "2")) } } + + test("SPARK-20529: Master should reply the address received from worker") { + val master = makeMaster() + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") 
+ } + + @volatile var receivedMasterAddress: RpcAddress = null + val fakeWorker = master.rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = master.rpcEnv + + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(_, _, masterAddress) => + receivedMasterAddress = masterAddress + } + }) + + master.self.send(RegisterWorker( + "1", + "localhost", + 9999, + fakeWorker, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost2", 10000))) + + eventually(timeout(10.seconds)) { + assert(receivedMasterAddress === RpcAddress("localhost2", 10000)) + } + } + + test("SPARK-19900: there should be a corresponding driver for the app after relaunching driver") { + val conf = new SparkConf().set("spark.worker.timeout", "1") + val master = makeMaster(conf) + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") + } + val worker1 = new MockWorker(master.self) + worker1.rpcEnv.setupEndpoint("worker", worker1) + val worker1Reg = RegisterWorker( + worker1.id, + "localhost", + 9998, + worker1.self, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost2", 10000)) + master.self.send(worker1Reg) + val driver = DeployTestUtils.createDriverDesc().copy(supervise = true) + master.self.askSync[SubmitDriverResponse](RequestSubmitDriver(driver)) + + eventually(timeout(10.seconds)) { + assert(worker1.apps.nonEmpty) + } + + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.workers(0).state == WorkerState.DEAD) + } + + val worker2 = new MockWorker(master.self) + worker2.rpcEnv.setupEndpoint("worker", worker2) + master.self.send(RegisterWorker( + worker2.id, + "localhost", + 9999, + worker2.self, + 10, + 1024, + "http://localhost:8081", + RpcAddress("localhost", 10001))) + eventually(timeout(10.seconds)) { + assert(worker2.apps.nonEmpty) + } + + master.self.send(worker1Reg) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + + val worker = masterState.workers.filter(w => w.id == worker1.id) + assert(worker.length == 1) + // make sure the `DriverStateChanged` arrives at Master. 
+ assert(worker(0).drivers.isEmpty) + assert(worker1.apps.isEmpty) + assert(worker1.drivers.isEmpty) + assert(worker2.apps.size == 1) + assert(worker2.drivers.size == 1) + assert(masterState.activeDrivers.length == 1) + assert(masterState.activeApps.length == 1) + } + } + + private def getDrivers(master: Master): HashSet[DriverInfo] = { + master.invokePrivate(_drivers()) + } + + private def getState(master: Master): RecoveryState.Value = { + master.invokePrivate(_state()) + } +} + +private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer) + extends StandaloneRecoveryModeFactory(conf, ser) { + import FakeRecoveryModeFactory.persistentData + + override def createPersistenceEngine(): PersistenceEngine = new PersistenceEngine { + + override def unpersist(name: String): Unit = { + persistentData.remove(name) + } + + override def persist(name: String, obj: Object): Unit = { + persistentData(name) = obj + } + + override def read[T: ClassTag](prefix: String): Seq[T] = { + persistentData.filter(_._1.startsWith(prefix)).map(_._2.asInstanceOf[T]).toSeq + } + } + + override def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { + new MonarchyLeaderAgent(master) + } +} + +private object FakeRecoveryModeFactory { + val persistentData = new HashMap[String, Object]() } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index dd50e33da30a..70887dc5dd97 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -11,7 +11,7 @@ * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and + * See the License for the specific language governing permissions and * limitations under the License. */ diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala new file mode 100644 index 000000000000..eeffc36070b4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { + private var delegationTokenManager: HadoopDelegationTokenManager = null + private var sparkConf: SparkConf = null + private var hadoopConf: Configuration = null + + override def beforeAll(): Unit = { + super.beforeAll() + + sparkConf = new SparkConf() + hadoopConf = new Configuration() + } + + test("Correctly load default credential providers") { + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("bogus") should be (None) + } + + test("disable hive credential provider") { + sparkConf.set("spark.security.credentials.hive.enabled", "false") + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should be (None) + } + + test("using deprecated configurations") { + sparkConf.set("spark.yarn.security.tokens.hadoopfs.enabled", "false") + sparkConf.set("spark.yarn.security.credentials.hive.enabled", "false") + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + } + + test("verify no credentials are obtained") { + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess) + val creds = new Credentials() + + // Tokens cannot be obtained from HDFS, Hive, HBase in unit tests. 
+ delegationTokenManager.obtainDelegationTokens(hadoopConf, creds) + val tokens = creds.getAllTokens + tokens.size() should be (0) + } + + test("obtain tokens For HiveMetastore") { + val hadoopConf = new Configuration() + hadoopConf.set("hive.metastore.kerberos.principal", "bob") + // thrift picks up on port 0 and bails out, without trying to talk to endpoint + hadoopConf.set("hive.metastore.uris", "http://localhost:0") + + val hiveCredentialProvider = new HiveDelegationTokenProvider() + val credentials = new Credentials() + hiveCredentialProvider.obtainDelegationTokens(hadoopConf, sparkConf, credentials) + + credentials.getAllTokens.size() should be (0) + } + + test("Obtain tokens For HBase") { + val hadoopConf = new Configuration() + hadoopConf.set("hbase.security.authentication", "kerberos") + + val hbaseTokenProvider = new HBaseDelegationTokenProvider() + val creds = new Credentials() + hbaseTokenProvider.obtainDelegationTokens(hadoopConf, sparkConf, creds) + + creds.getAllTokens.size should be (0) + } + + private[spark] def hadoopFSsToAccess(hadoopConf: Configuration): Set[FileSystem] = { + Set(FileSystem.get(hadoopConf)) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 101a44edd8ee..ce212a751331 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.worker -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.{Command, ExecutorState} @@ -25,7 +25,7 @@ import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorState import org.apache.spark.deploy.master.DriverState import org.apache.spark.rpc.{RpcAddress, RpcEnv} -class WorkerSuite extends SparkFunSuite with Matchers { +class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { import org.apache.spark.deploy.DeployTestUtils._ @@ -34,6 +34,25 @@ class WorkerSuite extends SparkFunSuite with Matchers { } def conf(opts: (String, String)*): SparkConf = new SparkConf(loadDefaults = false).setAll(opts) + private var _worker: Worker = _ + + private def makeWorker(conf: SparkConf): Worker = { + assert(_worker === null, "Some Worker's RpcEnv is leaked in tests") + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, securityMgr) + _worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "Worker", "/tmp", conf, securityMgr) + _worker + } + + after { + if (_worker != null) { + _worker.rpcEnv.shutdown() + _worker.rpcEnv.awaitTermination() + _worker = null + } + } + test("test isUseLocalNodeSSLConfig") { Worker.isUseLocalNodeSSLConfig(cmd("-Dasdf=dfgh")) shouldBe false Worker.isUseLocalNodeSSLConfig(cmd("-Dspark.ssl.useNodeLocalConf=true")) shouldBe true @@ -65,9 +84,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedExecutors (small number of executors)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedExecutors", 2.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = 
makeWorker(conf) // initialize workers for (i <- 0 until 5) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -91,9 +108,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedExecutors (more executors)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedExecutors", 30.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // initialize workers for (i <- 0 until 50) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -126,9 +141,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedDrivers (small number of drivers)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedDrivers", 2.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // initialize workers for (i <- 0 until 5) { val driverId = s"driverId-$i" @@ -152,9 +165,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedDrivers (more drivers)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedDrivers", 30.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // initialize workers for (i <- 0 until 50) { val driverId = s"driverId-$i" diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index efcad140350b..105a178f2d94 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -25,6 +25,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.Map import scala.concurrent.duration._ +import scala.language.postfixOps import org.mockito.ArgumentCaptor import org.mockito.Matchers.{any, eq => meq} @@ -32,7 +33,7 @@ import org.mockito.Mockito.{inOrder, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.concurrent.Eventually -import org.scalatest.mock.MockitoSugar +import org.scalatest.mockito.MockitoSugar import org.apache.spark._ import org.apache.spark.TaskState.TaskState @@ -41,7 +42,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.UninterruptibleThread @@ -233,6 +234,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val mockMemoryManager = mock[MemoryManager] when(mockEnv.conf).thenReturn(conf) when(mockEnv.serializer).thenReturn(serializer) + 
when(mockEnv.serializerManager).thenReturn(mock[SerializerManager]) when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) when(mockEnv.memoryManager).thenReturn(mockMemoryManager) diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index eae26fa742a2..7bcc2fb5231d 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -94,6 +94,8 @@ class TaskMetricsSuite extends SparkFunSuite { sr.setRemoteBytesRead(30L) sr.incRemoteBytesRead(3L) sr.incRemoteBytesRead(3L) + sr.setRemoteBytesReadToDisk(10L) + sr.incRemoteBytesReadToDisk(8L) sr.setLocalBytesRead(400L) sr.setLocalBytesRead(40L) sr.incLocalBytesRead(4L) @@ -110,6 +112,7 @@ class TaskMetricsSuite extends SparkFunSuite { assert(sr.remoteBlocksFetched == 12) assert(sr.localBlocksFetched == 24) assert(sr.remoteBytesRead == 36L) + assert(sr.remoteBytesReadToDisk == 18L) assert(sr.localBytesRead == 48L) assert(sr.fetchWaitTime == 60L) assert(sr.recordsRead == 72L) diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index b72cd8be2420..bf08276dbf97 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -261,4 +261,31 @@ class ConfigEntrySuite extends SparkFunSuite { data = 2 assert(conf.get(iConf) === 2) } + + test("conf entry: alternative keys") { + val conf = new SparkConf() + val iConf = ConfigBuilder(testKey("a")) + .withAlternative(testKey("b")) + .withAlternative(testKey("c")) + .intConf.createWithDefault(0) + + // no key is set, return default value. + assert(conf.get(iConf) === 0) + + // the primary key is set, the alternative keys are not set, return the value of primary key. + conf.set(testKey("a"), "1") + assert(conf.get(iConf) === 1) + + // the primary key and alternative keys are all set, return the value of primary key. + conf.set(testKey("b"), "2") + conf.set(testKey("c"), "3") + assert(conf.get(iConf) === 1) + + // the primary key is not set, (some of) the alternative keys are set, return the value of the + // first alternative key that is set. 
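The remaining assertions of this "alternative keys" test (continued just below) remove the primary key and check that alternatives are consulted in declaration order. Here is a standalone sketch of that first-match precedence, written against a plain Map because ConfigBuilder/ConfigEntry are Spark-internal; the object name and key strings are made up for the example.

```scala
// Illustration only: first-match precedence of a primary key followed by alternatives.
object AlternativeKeySketch {
  def lookup(
      settings: Map[String, String],
      primary: String,
      alternatives: Seq[String],
      default: Int): Int = {
    (primary +: alternatives)
      .collectFirst { case k if settings.contains(k) => settings(k).toInt }
      .getOrElse(default)
  }

  def main(args: Array[String]): Unit = {
    val alts = Seq("spark.test.b", "spark.test.c")
    println(lookup(Map.empty, "spark.test.a", alts, default = 0))            // 0
    println(lookup(Map("spark.test.a" -> "1", "spark.test.b" -> "2"),
      "spark.test.a", alts, default = 0))                                    // 1
    println(lookup(Map("spark.test.b" -> "2", "spark.test.c" -> "3"),
      "spark.test.a", alts, default = 0))                                    // 2
  }
}
```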
+ conf.remove(testKey("a")) + assert(conf.get(iConf) === 2) + conf.remove(testKey("b")) + assert(conf.get(iConf) === 3) + } } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index eb2b3ffd1509..85eeb5055ae0 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -117,7 +117,7 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft evictBlocksToFreeSpaceCalled.set(numBytesToFree) if (numBytesToFree <= mm.storageMemoryUsed) { // We can evict enough blocks to fulfill the request for space - mm.releaseStorageMemory(numBytesToFree, MemoryMode.ON_HEAP) + mm.releaseStorageMemory(numBytesToFree, mm.tungstenMemoryMode) evictedBlocks += Tuple2(null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L)) numBytesToFree } else { diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index c821054412d7..02b04cdbb2a5 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -303,4 +303,36 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes mm.invokePrivate[Unit](assertInvariants()) } + test("not enough free memory in the storage pool --OFF_HEAP") { + val conf = new SparkConf() + .set("spark.memory.offHeap.size", "1000") + .set("spark.testing.memory", "1000") + .set("spark.memory.offHeap.enabled", "true") + val taskAttemptId = 0L + val mm = UnifiedMemoryManager(conf, numCores = 1) + val ms = makeMemoryStore(mm) + val memoryMode = MemoryMode.OFF_HEAP + + assert(mm.acquireExecutionMemory(400L, taskAttemptId, memoryMode) === 400L) + assert(mm.storageMemoryUsed === 0L) + assert(mm.executionMemoryUsed === 400L) + + // Fail fast + assert(!mm.acquireStorageMemory(dummyBlock, 700L, memoryMode)) + assert(mm.storageMemoryUsed === 0L) + + assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode)) + assert(mm.storageMemoryUsed === 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) + + // Borrow 50 from execution memory + assert(mm.acquireStorageMemory(dummyBlock, 450L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 550L) + + // Borrow 50 from execution memory and evict 50 to free space + assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode)) + assertEvictBlocksToFreeSpaceCalled(ms, 50) + assert(mm.storageMemoryUsed === 600L) + } } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 5d522189a0c2..6f4203da1d86 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -34,7 +34,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter { @@ -319,6 +319,35 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext } assert(bytesRead >= 
tmpFile.length()) } + + test("input metrics with old Hadoop API in different thread") { + val bytesRead = runAndReturnBytesRead { + sc.textFile(tmpFilePath, 4).mapPartitions { iter => + val buf = new ArrayBuffer[String]() + ThreadUtils.runInNewThread("testThread", false) { + iter.flatMap(_.split(" ")).foreach(buf.append(_)) + } + + buf.iterator + }.count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with new Hadoop API in different thread") { + val bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], + classOf[Text]).mapPartitions { iter => + val buf = new ArrayBuffer[String]() + ThreadUtils.runInNewThread("testThread", false) { + iter.map(_._2.toString).flatMap(_.split(" ")).foreach(buf.append(_)) + } + + buf.iterator + }.count() + } + assert(bytesRead >= tmpFile.length()) + } } /** diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/StatsdSinkSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/StatsdSinkSuite.scala new file mode 100644 index 000000000000..0e21a36071c4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/metrics/sink/StatsdSinkSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.metrics.sink + +import java.net.{DatagramPacket, DatagramSocket} +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.Properties +import java.util.concurrent.TimeUnit._ + +import com.codahale.metrics._ + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.metrics.sink.StatsdSink._ + +class StatsdSinkSuite extends SparkFunSuite { + private val securityMgr = new SecurityManager(new SparkConf(false)) + private val defaultProps = Map( + STATSD_KEY_PREFIX -> "spark", + STATSD_KEY_PERIOD -> "1", + STATSD_KEY_UNIT -> "seconds", + STATSD_KEY_HOST -> "127.0.0.1" + ) + private val socketTimeout = 30000 // milliseconds + private val socketBufferSize = 8192 + + private def withSocketAndSink(testCode: (DatagramSocket, StatsdSink) => Any): Unit = { + val socket = new DatagramSocket + socket.setReceiveBufferSize(socketBufferSize) + socket.setSoTimeout(socketTimeout) + val props = new Properties + defaultProps.foreach(e => props.put(e._1, e._2)) + props.put(STATSD_KEY_PORT, socket.getLocalPort.toString) + val registry = new MetricRegistry + val sink = new StatsdSink(props, registry, securityMgr) + try { + testCode(socket, sink) + } finally { + socket.close() + } + } + + test("metrics StatsD sink with Counter") { + withSocketAndSink { (socket, sink) => + val counter = new Counter + counter.inc(12) + sink.registry.register("counter", counter) + sink.report() + + val p = new DatagramPacket(new Array[Byte](socketBufferSize), socketBufferSize) + socket.receive(p) + + val result = new String(p.getData, 0, p.getLength, UTF_8) + assert(result === "spark.counter:12|c", "Counter metric received should match data sent") + } + } + + test("metrics StatsD sink with Gauge") { + withSocketAndSink { (socket, sink) => + val gauge = new Gauge[Double] { + override def getValue: Double = 1.23 + } + sink.registry.register("gauge", gauge) + sink.report() + + val p = new DatagramPacket(new Array[Byte](socketBufferSize), socketBufferSize) + socket.receive(p) + + val result = new String(p.getData, 0, p.getLength, UTF_8) + assert(result === "spark.gauge:1.23|g", "Gauge metric received should match data sent") + } + } + + test("metrics StatsD sink with Histogram") { + withSocketAndSink { (socket, sink) => + val p = new DatagramPacket(new Array[Byte](socketBufferSize), socketBufferSize) + val histogram = new Histogram(new UniformReservoir) + histogram.update(10) + histogram.update(20) + histogram.update(30) + sink.registry.register("histogram", histogram) + sink.report() + + val expectedResults = Set( + "spark.histogram.count:3|g", + "spark.histogram.max:30|ms", + "spark.histogram.mean:20.00|ms", + "spark.histogram.min:10|ms", + "spark.histogram.stddev:10.00|ms", + "spark.histogram.p50:20.00|ms", + "spark.histogram.p75:30.00|ms", + "spark.histogram.p95:30.00|ms", + "spark.histogram.p98:30.00|ms", + "spark.histogram.p99:30.00|ms", + "spark.histogram.p999:30.00|ms" + ) + + (1 to expectedResults.size).foreach { i => + socket.receive(p) + val result = new String(p.getData, 0, p.getLength, UTF_8) + logInfo(s"Received histogram result $i: '$result'") + assert(expectedResults.contains(result), + "Histogram metric received should match data sent") + } + } + } + + test("metrics StatsD sink with Timer") { + withSocketAndSink { (socket, sink) => + val p = new DatagramPacket(new Array[Byte](socketBufferSize), socketBufferSize) + val timer = new Timer() + timer.update(1, SECONDS) + timer.update(2, SECONDS) + timer.update(3, SECONDS) + 
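The assertions in this suite (the Timer assertions continue below) compare received datagrams against strings of the form `<prefix>.<metric>:<value>|<type>`, the plain-text StatsD line format. A small illustrative parser for that shape follows; the suite itself only compares whole strings, and the case class and object here are hypothetical.

```scala
// Illustrative parser for the "name:value|type" lines asserted in this suite.
final case class StatsdLine(name: String, value: String, metricType: String)

object StatsdLineSketch {
  def parse(line: String): Option[StatsdLine] = line.split("[:|]") match {
    case Array(name, value, tpe) => Some(StatsdLine(name, value, tpe))
    case _                       => None
  }

  def main(args: Array[String]): Unit = {
    println(parse("spark.counter:12|c"))  // Some(StatsdLine(spark.counter,12,c))
    println(parse("spark.gauge:1.23|g"))  // Some(StatsdLine(spark.gauge,1.23,g))
    println(parse("not a metric"))        // None
  }
}
```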
sink.registry.register("timer", timer) + sink.report() + + val expectedResults = Set( + "spark.timer.max:3000.00|ms", + "spark.timer.mean:2000.00|ms", + "spark.timer.min:1000.00|ms", + "spark.timer.stddev:816.50|ms", + "spark.timer.p50:2000.00|ms", + "spark.timer.p75:3000.00|ms", + "spark.timer.p95:3000.00|ms", + "spark.timer.p98:3000.00|ms", + "spark.timer.p99:3000.00|ms", + "spark.timer.p999:3000.00|ms", + "spark.timer.count:3|g", + "spark.timer.m1_rate:0.00|ms", + "spark.timer.m5_rate:0.00|ms", + "spark.timer.m15_rate:0.00|ms" + ) + // mean rate varies on each test run + val oneMoreResult = """spark.timer.mean_rate:\d+\.\d\d\|ms""" + + (1 to (expectedResults.size + 1)).foreach { i => + socket.receive(p) + val result = new String(p.getData, 0, p.getLength, UTF_8) + logInfo(s"Received timer result $i: '$result'") + assert(expectedResults.contains(result) || result.matches(oneMoreResult), + "Timer metric received should match data sent") + } + } + } +} + diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index fe8955840d72..21138bd4a16b 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -22,22 +22,23 @@ import java.nio._ import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit -import scala.concurrent.{Await, Promise} +import scala.concurrent.Promise import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} import com.google.common.io.CharStreams import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar -import org.scalatest.ShouldMatchers +import org.scalatest.Matchers +import org.scalatest.mockito.MockitoSugar import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.storage.{BlockId, ShuffleBlockId} +import org.apache.spark.util.ThreadUtils -class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers { +class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with Matchers { test("security default off") { val conf = new SparkConf() .set("spark.app.id", "app-id") @@ -164,9 +165,9 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { promise.success(data.retain()) } - }) + }, null) - Await.ready(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) + ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index 271ab8b14883..f7bc3725d727 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.network.BlockDataManager class NettyBlockTransferServiceSuite extends SparkFunSuite with BeforeAndAfterEach - with ShouldMatchers { + with 
Matchers { private var service0: NettyBlockTransferService = _ private var service1: NettyBlockTransferService = _ @@ -80,7 +80,8 @@ class NettyBlockTransferServiceSuite private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = { actualPort should be >= expectedPort // avoid testing equality in case of simultaneous tests - actualPort should be <= (expectedPort + 10) + // the default value for `spark.port.maxRetries` is 100 under test + actualPort should be <= (expectedPort + 100) } private def createService(port: Int): NettyBlockTransferService = { diff --git a/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala index da3256bd882e..3c1208c2c375 100644 --- a/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala @@ -23,21 +23,21 @@ class CountEvaluatorSuite extends SparkFunSuite { test("test count 0") { val evaluator = new CountEvaluator(10, 0.95) - assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity)) evaluator.merge(1, 0) - assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity)) } test("test count >= 1") { val evaluator = new CountEvaluator(10, 0.95) evaluator.merge(1, 1) - assert(new BoundedDouble(10.0, 0.95, 1.0, 36.0) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(10.0, 0.95, 5.0, 16.0)) evaluator.merge(1, 3) - assert(new BoundedDouble(20.0, 0.95, 7.0, 41.0) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(20.0, 0.95, 13.0, 28.0)) evaluator.merge(1, 8) - assert(new BoundedDouble(40.0, 0.95, 24.0, 61.0) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(40.0, 0.95, 30.0, 51.0)) (4 to 10).foreach(_ => evaluator.merge(1, 10)) - assert(new BoundedDouble(82.0, 1.0, 82.0, 82.0) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(82.0, 1.0, 82.0, 82.0)) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index b29a53cffeb5..de0e71a332f2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -20,17 +20,17 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore import scala.concurrent._ -import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration.Duration import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.TimeLimits import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.util.ThreadUtils -class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { +class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with TimeLimits { @transient private var sc: SparkContext = _ @@ -130,10 +130,10 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim info("Should not have reached this code path (onComplete matching Failure)") throw new 
Exception("Task should succeed") } - f.onSuccess { case a: Any => + f.foreach { a => sem.release() } - f.onFailure { case t => + f.failed.foreach { t => info("Should not have reached this code path (onFailure)") throw new Exception("Task should succeed") } @@ -164,11 +164,11 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim case scala.util.Failure(e) => sem.release() } - f.onSuccess { case a: Any => + f.foreach { a => info("Should not have reached this code path (onSuccess)") throw new Exception("Task should fail") } - f.onFailure { case t => + f.failed.foreach { t => sem.release() } intercept[SparkException] { diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala index 2802cd975292..478f0690f8e4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.rdd +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} + import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -168,6 +173,10 @@ class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { // Collecting the RDD should now fail with an informative exception val blockId = RDDBlockId(rdd.id, numPartitions - 1) bmm.removeBlock(blockId) + // Wait until the block has been removed successfully. + eventually(timeout(1 seconds), interval(100 milliseconds)) { + assert(bmm.getBlockStatus(blockId).isEmpty) + } try { rdd.collect() fail("Collect should have failed if local checkpoint block is removed...") diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 02df157be377..44dd955ce869 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -561,7 +561,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { pairs.saveAsHadoopFile( "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf) } - assert(e.getMessage contains "failed to write") + assert(e.getCause.getMessage contains "failed to write") assert(FakeWriterWithCallback.calledBy === "write,callback,close") assert(FakeWriterWithCallback.exception != null, "exception should be captured") diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index ad56715656c8..e994d724c462 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat} import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDDSuiteUtils._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} class RDDSuite extends SparkFunSuite with SharedSparkContext { var tempDir: File = _ @@ -192,6 +192,23 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(ser.serialize(union.partitions.head).limit() < 2000) } + test("fold") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def op: (Int, Int) => 
Int = (c: Int, x: Int) => c + x + val sum = rdd.fold(0)(op) + assert(sum === -1000) + } + + test("fold with op modifying first arg") { + val rdd = sc.makeRDD(-1000 until 1000, 10).map(x => Array(x)) + def op: (Array[Int], Array[Int]) => Array[Int] = { (c: Array[Int], x: Array[Int]) => + c(0) += x(0) + c + } + val sum = rdd.fold(Array(0))(op) + assert(sum(0) === -1000) + } + test("aggregate") { val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) type StringMap = HashMap[String, Int] @@ -218,7 +235,19 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { def combOp: (Long, Long) => Long = (c1: Long, c2: Long) => c1 + c2 for (depth <- 1 until 10) { val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) - assert(sum === -1000L) + assert(sum === -1000) + } + } + + test("treeAggregate with ops modifying first args") { + val rdd = sc.makeRDD(-1000 until 1000, 10).map(x => Array(x)) + def op: (Array[Int], Array[Int]) => Array[Int] = { (c: Array[Int], x: Array[Int]) => + c(0) += x(0) + c + } + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(Array(0))(op, op, depth) + assert(sum(0) === -1000) } } @@ -318,16 +347,18 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { val partitions = repartitioned.glom().collect() // assert all elements are present assert(repartitioned.collect().sortWith(_ > _).toSeq === input.toSeq.sortWith(_ > _).toSeq) - // assert no bucket is overloaded + // assert no bucket is overloaded or empty for (partition <- partitions) { val avg = input.size / finalPartitions val maxPossible = avg + initialPartitions - assert(partition.length <= maxPossible) + assert(partition.length <= maxPossible) + assert(!partition.isEmpty) } } testSplitPartitions(Array.fill(100)(1), 10, 20) testSplitPartitions(Array.fill(10000)(1) ++ Array.fill(10000)(2), 20, 100) + testSplitPartitions(Array.fill(1000)(1), 250, 128) } test("coalesced RDDs") { @@ -1082,6 +1113,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(totalPartitionCount == 10) } + test("SPARK-18406: race between end-of-task and completion iterator read lock release") { + val rdd = sc.parallelize(1 to 1000, 10) + rdd.cache() + + rdd.mapPartitions { iter => + ThreadUtils.runInNewThread("TestThread") { + // Iterate to the end of the input iterator, to cause the CompletionIterator completion to + // fire outside of the task's main thread. + while (iter.hasNext) { + iter.next() + } + iter + } + }.collect() + } + // NOTE // Below tests calling sc.stop() have to be the last tests in this suite. 
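The `fold`, `fold with op modifying first arg`, and `treeAggregate` tests above all expect -1000. A quick plain-Scala check of that arithmetic (no SparkContext needed; the object name is just for the example):

```scala
// Sanity check of the expected value in the fold/treeAggregate tests:
// in -1000 until 1000, each k in 1..999 cancels -k, leaving -1000 itself.
object FoldExpectedValue {
  def main(args: Array[String]): Unit = {
    println((-1000 until 1000).sum)                // -1000
    println((-1000 until 1000).foldLeft(0)(_ + _)) // -1000, same op as the tests
  }
}
```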
If there are tests // running after them and if they access sc those tests will fail as sc is already closed, because diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 31d9dd3de8ac..a799b1cfb076 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -22,8 +22,8 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.UUID import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit} -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -633,7 +633,12 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("port conflict") { val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port) - assert(anotherEnv.address.port != env.address.port) + try { + assert(anotherEnv.address.port != env.address.port) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } } private def testSend(conf: SparkConf): Unit = { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index 2b1bce4d208f..f9481f875d43 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.rpc.netty -import org.scalatest.mock.MockitoSugar +import org.scalatest.mockito.MockitoSugar import org.apache.spark._ import org.apache.spark.network.client.TransportClient @@ -31,7 +31,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar { port: Int, clientMode: Boolean = false): RpcEnv = { val config = RpcEnvConfig(conf, "test", "localhost", "localhost", port, - new SecurityManager(conf), clientMode) + new SecurityManager(conf), 0, clientMode) new NettyRpcEnvFactory().create(config) } @@ -47,7 +47,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar { test("advertise address different from bind address") { val sparkConf = new SparkConf() val config = RpcEnvConfig(sparkConf, "test", "localhost", "example.com", 0, - new SecurityManager(sparkConf), false) + new SecurityManager(sparkConf), 0, false) val env = new NettyRpcEnvFactory().create(config) try { assert(env.address.hostPort.startsWith("example.com:")) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index f6015cd51c2b..d3bbfd11d406 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -115,8 +115,9 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM withBackend(runBackend _) { val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) awaitJobTermination(jobFuture, duration) - val pattern = ("Aborting TaskSet 0.0 because task .* " + - "cannot run anywhere due to node and executor blacklist").r + val pattern = ( + s"""|Aborting TaskSet 0.0 because task .* + |cannot run anywhere due to node and executor blacklist""".stripMargin).r assert(pattern.findFirstIn(failure.getMessage).isDefined, s"Couldn't find $pattern in 
${failure.getMessage()}") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 2b18ebee79a2..cd1b7a9e5ab1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.scheduler -import org.mockito.invocation.InvocationOnMock import org.mockito.Matchers.any import org.mockito.Mockito.{never, verify, when} +import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach -import org.scalatest.mock.MockitoSugar +import org.scalatest.mockito.MockitoSugar import org.apache.spark._ import org.apache.spark.internal.config @@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M sc = new SparkContext(conf) val scheduler = mock[TaskSchedulerImpl] when(scheduler.sc).thenReturn(sc) - when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker) + when(scheduler.mapOutputTracker).thenReturn( + SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]) scheduler } @@ -109,7 +110,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val taskSetBlacklist = createTaskSetBlacklist(stageId) if (stageId % 2 == 0) { // fail one task in every other taskset - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") failuresSoFar += 1 } blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) @@ -131,7 +133,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // for many different stages, executor 1 fails a task, and then the taskSet fails. (0 until failuresUntilBlacklisted * 10).foreach { stage => val taskSetBlacklist = createTaskSetBlacklist(stage) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") } assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) } @@ -146,7 +149,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val numFailures = math.max(conf.get(config.MAX_FAILURES_PER_EXEC), conf.get(config.MAX_FAILURES_PER_EXEC_STAGE)) (0 until numFailures).foreach { index => - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = index) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = index, failureReason = "testing") } assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) @@ -169,7 +173,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. 
(0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) assert(blacklist.nodeBlacklist() === Set()) @@ -182,7 +187,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. (0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) assert(blacklist.nodeBlacklist() === Set("hostA")) @@ -206,7 +212,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail one more task, but executor isn't put back into blacklist since the count of failures // on that executor should have been reset to 0. val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(2, 0, taskSetBlacklist2.execToFailures) assert(blacklist.nodeBlacklist() === Set()) assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) @@ -220,7 +227,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Lets say that executor 1 dies completely. We get some task failures, but // the taskset then finishes successfully (elsewhere). (0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.handleRemovedExecutor("1") blacklist.updateBlacklistForSuccessfulTaskSet( @@ -235,7 +243,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Now another executor gets spun up on that host, but it also dies. 
val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) (0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.handleRemovedExecutor("2") blacklist.updateBlacklistForSuccessfulTaskSet( @@ -278,7 +287,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M def failOneTaskInTaskSet(exec: String): Unit = { val taskSetBlacklist = createTaskSetBlacklist(stageId = stageId) - taskSetBlacklist.updateBlacklistForFailedTask("host-" + exec, exec, 0) + taskSetBlacklist.updateBlacklistForFailedTask("host-" + exec, exec, 0, "testing") blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) stageId += 1 } @@ -353,12 +362,12 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) // Taskset1 has one failure immediately - taskSetBlacklist1.updateBlacklistForFailedTask("host-1", "1", 0) + taskSetBlacklist1.updateBlacklistForFailedTask("host-1", "1", 0, "testing") // Then we have a *long* delay, much longer than the timeout, before any other failures or // taskset completion clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS * 5) // After the long delay, we have one failure on taskset 2, on the same executor - taskSetBlacklist2.updateBlacklistForFailedTask("host-1", "1", 0) + taskSetBlacklist2.updateBlacklistForFailedTask("host-1", "1", 0, "testing") // Finally, we complete both tasksets. Its important here to complete taskset2 *first*. We // want to make sure that when taskset 1 finishes, even though we've now got two task failures, // we realize that the task failure we just added was well before the timeout. 
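The timeout-handling test in this hunk advances a manual clock well past `BLACKLIST_TIMEOUT_MILLIS` and expects the earlier failure to have aged out. A compact sketch of that bookkeeping follows, under the assumption that a failure only counts while `now < failureTime + timeout`; this is an illustration, not BlacklistTracker's code, and the names are hypothetical.

```scala
// Illustration: failures expire once the configured timeout has elapsed.
object BlacklistExpirySketch {
  final case class Failure(executorId: String, atMillis: Long)

  def activeFailures(
      failures: Seq[Failure],
      nowMillis: Long,
      timeoutMillis: Long): Seq[Failure] =
    failures.filter(f => nowMillis < f.atMillis + timeoutMillis)

  def main(args: Array[String]): Unit = {
    val timeout = 60000L
    val failures = Seq(Failure("1", atMillis = 0L), Failure("1", atMillis = 5 * timeout))
    // At t = 5 * timeout the first failure has aged out, so only one remains active.
    println(activeFailures(failures, nowMillis = 5 * timeout, timeoutMillis = timeout).size) // 1
  }
}
```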
@@ -376,16 +385,20 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // we blacklist executors on two different hosts -- make sure that doesn't lead to any // node blacklisting val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "1", 2)) assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) - taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) - taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 0, failureReason = "testing") + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(1, 0, taskSetBlacklist1.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "2", 2)) @@ -394,8 +407,10 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Finally, blacklist another executor on the same node as the original blacklisted executor, // and make sure this time we *do* blacklist the node. val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 0) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 0) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 1) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 0, failureReason = "testing") + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2", "3")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "3", 2)) @@ -485,7 +500,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) @@ -496,7 +512,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. 
(0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) @@ -511,7 +528,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) @@ -522,11 +540,67 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. (0 until 4).foreach { partition => - taskSetBlacklist3.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist3.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures) verify(allocationClientMock).killExecutors(Seq("2"), true, true) verify(allocationClientMock).killExecutorsOnHost("hostA") } + + test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { + val allocationClientMock = mock[ExecutorAllocationClient] + when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { + // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist + // is updated before we ask the executor allocation client to kill all the executors + // on a particular host. + override def answer(invocation: InvocationOnMock): Boolean = { + if (blacklist.nodeBlacklist.contains("hostA") == false) { + throw new IllegalStateException("hostA should be on the blacklist") + } + true + } + }) + + conf.set(config.BLACKLIST_FETCH_FAILURE_ENABLED, true) + blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) + + // Disable auto-kill. Blacklist an executor and make sure killExecutors is not called. + conf.set(config.BLACKLIST_KILL_ENABLED, false) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") + + verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutorsOnHost(any()) + + // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. 
+ conf.set(config.BLACKLIST_KILL_ENABLED, true) + blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) + clock.advance(1000) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") + + verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock, never).killExecutorsOnHost(any()) + + assert(blacklist.executorIdToBlacklistStatus.contains("1")) + assert(blacklist.executorIdToBlacklistStatus("1").node === "hostA") + assert(blacklist.executorIdToBlacklistStatus("1").expiryTime === + 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty) + + // Enable external shuffle service to see if all the executors on this node will be killed. + conf.set(config.SHUFFLE_SERVICE_ENABLED, true) + clock.advance(1000) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "2") + + verify(allocationClientMock, never).killExecutors(Seq("2"), true, true) + verify(allocationClientMock).killExecutorsOnHost("hostA") + + assert(blacklist.nodeIdToBlacklistExpiryTime.contains("hostA")) + assert(blacklist.nodeIdToBlacklistExpiryTime("hostA") === + 2000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a10941b579fe..6222e576d1ce 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -18,14 +18,14 @@ package org.apache.spark.scheduler import java.util.Properties -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal -import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.TimeLimits import org.scalatest.time.SpanSugar._ import org.apache.spark._ @@ -98,7 +98,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeouts { +class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits { import DAGSchedulerSuite._ @@ -131,6 +131,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} + override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None } @@ -396,6 +397,73 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } + test("All shuffle files on the slave should be cleaned up when slave lost") { + // reset the test context with the right shuffle service config + afterEach() + val conf = new SparkConf() + conf.set("spark.shuffle.service.enabled", "true") + conf.set("spark.files.fetchFailure.unRegisterOutputOnHost", "true") + init(conf) + runEvent(ExecutorAdded("exec-hostA1", "hostA")) + runEvent(ExecutorAdded("exec-hostA2", "hostA")) + 
runEvent(ExecutorAdded("exec-hostB", "hostB")) + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(3)) + val firstShuffleId = firstShuffleDep.shuffleId + val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(3)) + val secondShuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + // map stage1 completes successfully, with one task on each executor + complete(taskSets(0), Seq( + (Success, + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + (Success, + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + (Success, makeMapStatus("hostB", 1)) + )) + // map stage2 completes successfully, with one task on each executor + complete(taskSets(1), Seq( + (Success, + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + (Success, + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + (Success, makeMapStatus("hostB", 1)) + )) + // make sure our test setup is correct + val initialMapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses + // val initialMapStatus1 = mapOutputTracker.mapStatuses.get(0).get + assert(initialMapStatus1.count(_ != null) === 3) + assert(initialMapStatus1.map{_.location.executorId}.toSet === + Set("exec-hostA1", "exec-hostA2", "exec-hostB")) + + val initialMapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses + // val initialMapStatus1 = mapOutputTracker.mapStatuses.get(0).get + assert(initialMapStatus2.count(_ != null) === 3) + assert(initialMapStatus2.map{_.location.executorId}.toSet === + Set("exec-hostA1", "exec-hostA2", "exec-hostB")) + + // reduce stage fails with a fetch failure from one host + complete(taskSets(2), Seq( + (FetchFailed(BlockManagerId("exec-hostA2", "hostA", 12345), firstShuffleId, 0, 0, "ignored"), + null) + )) + + // Here is the main assertion -- make sure that we de-register + // the map outputs for both map stage from both executors on hostA + + val mapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses + assert(mapStatus1.count(_ != null) === 1) + assert(mapStatus1(2).location.executorId === "exec-hostB") + assert(mapStatus1(2).location.host === "hostB") + + val mapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses + assert(mapStatus2.count(_ != null) === 1) + assert(mapStatus2(2).location.executorId === "exec-hostB") + assert(mapStatus2(2).location.host === "hostB") + } + test("zero split job") { var numResults = 0 var failureReason: Option[Exception] = None @@ -565,6 +633,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: BlockManagerId): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} + override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None } val noKillScheduler = new DAGScheduler( @@ -682,7 +751,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // Helper functions to extract commonly used code in Fetch Failure test cases private def setupStageAbortTest(sc: SparkContext) { - sc.listenerBus.addListener(new EndListener()) + sc.listenerBus.addToSharedQueue(new 
EndListener()) ended = false jobResult = null } @@ -1277,10 +1346,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou */ test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { val firstRDD = new MyRDD(sc, 3, Nil) - val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(3)) val firstShuffleId = firstShuffleDep.shuffleId val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) @@ -1583,7 +1652,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou */ test("run trivial shuffle with out-of-band executor failure and retry") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) @@ -1791,7 +1860,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou test("reduce tasks should be placed locally with map output") { // Create a shuffleMapRdd with 1 partition val shuffleMapRdd = new MyRDD(sc, 1, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) @@ -2277,6 +2346,36 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou (Success, 1))) } + test("task end event should have updated accumulators (SPARK-20342)") { + val tasks = 10 + + val accumId = new AtomicLong() + val foundCount = new AtomicLong() + val listener = new SparkListener() { + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = { + event.taskInfo.accumulables.find(_.id == accumId.get).foreach { _ => + foundCount.incrementAndGet() + } + } + } + sc.addSparkListener(listener) + + // Try a few times in a loop to make sure. This is not guaranteed to fail when the bug exists, + // but it should at least make the test flaky. If the bug is fixed, this should always pass. + (1 to 10).foreach { i => + foundCount.set(0L) + + val accum = sc.longAccumulator(s"accum$i") + accumId.set(accum.id) + + sc.parallelize(1 to tasks, tasks).foreach { _ => + accum.add(1L) + } + sc.listenerBus.waitUntilEmpty(1000) + assert(foundCount.get() === tasks) + } + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. 
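The SPARK-20342 regression test above drives the public accumulator API and checks that task-end listener events carry the updated values. For reference, a self-contained usage sketch of that API; the local master, app name, and accumulator name are arbitrary choices for the example.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Minimal usage of the driver-side long accumulator API exercised by the test.
object AccumulatorExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("accum-example"))
    try {
      val accum = sc.longAccumulator("processed")
      sc.parallelize(1 to 100, 4).foreach(_ => accum.add(1L))
      // The driver-side value reflects updates merged from all completed tasks.
      println(accum.value) // 100
    } finally {
      sc.stop()
    }
  }
}
```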
diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 4c3d0b102152..6b42775ccb0f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -18,19 +18,20 @@ package org.apache.spark.scheduler import java.io.{File, FileOutputStream, InputStream, IOException} -import java.net.URI import scala.collection.mutable import scala.io.Source import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.io._ +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{JsonProtocol, Utils} /** @@ -50,7 +51,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit private var testDirPath: Path = _ before { - testDir = Utils.createTempDir() + testDir = Utils.createTempDir(namePrefix = s"history log") testDir.deleteOnExit() testDirPath = new Path(testDir.getAbsolutePath()) } @@ -109,7 +110,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit test("Log overwriting") { val logUri = EventLoggingListener.getLogPath(testDir.toURI, "test", None) - val logPath = new URI(logUri).getPath + val logPath = new Path(logUri).toUri.getPath // Create file before writing the event log new FileOutputStream(new File(logPath)).close() // Expected IOException, since we haven't enabled log overwrite. @@ -155,17 +156,18 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") val eventLogger = new EventLoggingListener(logName, None, testDirPath.toUri(), conf) - val listenerBus = new LiveListenerBus(sc) + val listenerBus = new LiveListenerBus(conf) val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey", None) val applicationEnd = SparkListenerApplicationEnd(1000L) // A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite eventLogger.start() - listenerBus.start() - listenerBus.addListener(eventLogger) - listenerBus.postToAll(applicationStart) - listenerBus.postToAll(applicationEnd) + listenerBus.start(Mockito.mock(classOf[SparkContext]), Mockito.mock(classOf[MetricsSystem])) + listenerBus.addToEventLogQueue(eventLogger) + listenerBus.post(applicationStart) + listenerBus.post(applicationEnd) + listenerBus.stop() eventLogger.stop() // Verify file contains exactly the two events logged @@ -290,7 +292,7 @@ object EventLoggingListenerSuite { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") conf.set("spark.eventLog.testing", "true") - conf.set("spark.eventLog.dir", logDir.toUri.toString) + conf.set("spark.eventLog.dir", logDir.toString) compressionCodec.foreach { codec => conf.set("spark.eventLog.compress", "true") conf.set("spark.io.compression.codec", codec) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index ba56af8215cd..a4e4ea7cd289 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -84,6 +84,7 @@ private class DummyTaskScheduler extends TaskScheduler { override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} + override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None def executorHeartbeatReceived( execId: String, diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 759d52fca5ce..144e5afdcdd7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -17,12 +17,17 @@ package org.apache.spark.scheduler +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + import scala.util.Random +import org.mockito.Mockito._ import org.roaringbitmap.RoaringBitmap -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.internal.config +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerId class MapStatusSuite extends SparkFunSuite { @@ -128,4 +133,37 @@ class MapStatusSuite extends SparkFunSuite { assert(size1 === size2) assert(!success) } + + test("Blocks which are bigger than SHUFFLE_ACCURATE_BLOCK_THRESHOLD should not be " + + "underestimated.") { + val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "1000") + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + // Value of element in sizes is equal to the corresponding index. 
+ val sizes = (0L to 2000L).toArray + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val arrayStream = new ByteArrayOutputStream(102400) + val objectOutputStream = new ObjectOutputStream(arrayStream) + assert(status1.isInstanceOf[HighlyCompressedMapStatus]) + objectOutputStream.writeObject(status1) + objectOutputStream.flush() + val array = arrayStream.toByteArray + val objectInput = new ObjectInputStream(new ByteArrayInputStream(array)) + val status2 = objectInput.readObject().asInstanceOf[HighlyCompressedMapStatus] + (1001 to 2000).foreach { + case part => assert(status2.getSizeForBlock(part) >= sizes(part)) + } + } + + test("SPARK-21133 HighlyCompressedMapStatus#writeExternal throws NPE") { + val conf = new SparkConf() + .set("spark.serializer", classOf[KryoSerializer].getName) + .setMaster("local") + .setAppName("SPARK-21133") + withSpark(new SparkContext(conf)) { sc => + val count = sc.parallelize(0 until 3000, 10).repartition(2001).collect().length + assert(count === 3000) + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala index 32cdf16dd331..a27dadcf49bf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} -import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.TimeLimits import org.scalatest.time.{Seconds, Span} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext} @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils class OutputCommitCoordinatorIntegrationSuite extends SparkFunSuite with LocalSparkContext - with Timeouts { + with TimeLimits { override def beforeAll(): Unit = { super.beforeAll() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index e51e6a0d3ff6..03b190390249 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.scheduler import java.io.File +import java.util.Date import java.util.concurrent.TimeoutException import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.hadoop.mapred.{JobConf, OutputCommitter, TaskAttemptContext, TaskAttemptID} +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.TaskType import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -31,7 +33,7 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.internal.io.SparkHadoopWriter +import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.{ThreadUtils, Utils} @@ -113,7 +115,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = { if (!hasDequeuedSpeculatedTask) { hasDequeuedSpeculatedTask 
= true - Some(0, TaskLocality.PROCESS_LOCAL) + Some((0, TaskLocality.PROCESS_LOCAL)) } else { None } @@ -214,6 +216,8 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { */ private case class OutputCommitFunctions(tempDirPath: String) { + private val jobId = new SerializableWritable(SparkHadoopWriterUtils.createJobID(new Date, 0)) + // Mock output committer that simulates a successful commit (after commit is authorized) private def successfulOutputCommitter = new FakeOutputCommitter { override def commitTask(context: TaskAttemptContext): Unit = { @@ -256,14 +260,22 @@ private case class OutputCommitFunctions(tempDirPath: String) { def jobConf = new JobConf { override def getOutputCommitter(): OutputCommitter = outputCommitter } - val sparkHadoopWriter = new SparkHadoopWriter(jobConf) { - override def newTaskAttemptContext( - conf: JobConf, - attemptId: TaskAttemptID): TaskAttemptContext = { - mock(classOf[TaskAttemptContext]) - } - } - sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber) - sparkHadoopWriter.commit() + + // Instantiate committer. + val committer = FileCommitProtocol.instantiate( + className = classOf[HadoopMapRedCommitProtocol].getName, + jobId = jobId.value.getId.toString, + outputPath = jobConf.get("mapred.output.dir")) + + // Create TaskAttemptContext. + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val taskAttemptId = (ctx.taskAttemptId % Int.MaxValue).toInt + val attemptId = new TaskAttemptID( + new TaskID(jobId.value, TaskType.MAP, ctx.partitionId), taskAttemptId) + val taskContext = new TaskAttemptContextImpl(jobConf, attemptId) + + committer.setupTask(taskContext) + committer.commitTask(taskContext) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index 4901062a7855..5bd3955f5adb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.FileNotFoundException import java.util.Properties import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} @@ -292,6 +293,49 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { } } + test("Fair Scheduler should build fair scheduler when " + + "valid spark.scheduler.allocation.file property is set") { + val xmlPath = getClass.getClassLoader.getResource("fairscheduler-with-valid-data.xml").getFile() + val conf = new SparkConf().set(SCHEDULER_ALLOCATION_FILE_PROPERTY, xmlPath) + sc = new SparkContext(LOCAL, APP_NAME, conf) + + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + schedulableBuilder.buildPools() + + verifyPool(rootPool, schedulableBuilder.DEFAULT_POOL_NAME, 0, 1, FIFO) + verifyPool(rootPool, "pool1", 3, 1, FIFO) + verifyPool(rootPool, "pool2", 4, 2, FAIR) + verifyPool(rootPool, "pool3", 2, 3, FAIR) + } + + test("Fair Scheduler should use default file(fairscheduler.xml) if it exists in classpath " + + "and spark.scheduler.allocation.file property is not set") { + val conf = new SparkConf() + sc = new SparkContext(LOCAL, APP_NAME, conf) + + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + 
schedulableBuilder.buildPools() + + verifyPool(rootPool, schedulableBuilder.DEFAULT_POOL_NAME, 0, 1, FIFO) + verifyPool(rootPool, "1", 2, 1, FIFO) + verifyPool(rootPool, "2", 3, 1, FIFO) + verifyPool(rootPool, "3", 0, 1, FIFO) + } + + test("Fair Scheduler should throw FileNotFoundException " + + "when invalid spark.scheduler.allocation.file property is set") { + val conf = new SparkConf().set(SCHEDULER_ALLOCATION_FILE_PROPERTY, "INVALID_FILE_PATH") + sc = new SparkContext(LOCAL, APP_NAME, conf) + + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + intercept[FileNotFoundException] { + schedulableBuilder.buildPools() + } + } + private def verifyPool(rootPool: Pool, poolName: String, expectedInitMinShare: Int, expectedInitWeight: Int, expectedSchedulingMode: SchedulingMode): Unit = { val selectedPool = rootPool.getSchedulableByName(poolName) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 1732aca9417e..d17e3864854a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -17,15 +17,17 @@ package org.apache.spark.scheduler -import java.io.{File, PrintWriter} +import java.io._ import java.net.URI +import java.util.concurrent.atomic.AtomicInteger +import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.io.CompressionCodec +import org.apache.spark.io.{CompressionCodec, LZ4CompressionCodec} import org.apache.spark.util.{JsonProtocol, JsonProtocolSuite, Utils} /** @@ -72,6 +74,60 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp assert(eventMonster.loggedEvents(1) === JsonProtocol.sparkEventToJson(applicationEnd)) } + /** + * Test replaying compressed spark history file that internally throws an EOFException. To + * avoid sensitivity to the compression specifics the test forces an EOFException to occur + * while reading bytes from the underlying stream (such as observed in actual history files + * in some cases) and forces specific failure handling. This validates correctness in both + * cases when maybeTruncated is true or false. + */ + test("Replay compressed inprogress log file succeeding on partial read") { + val buffered = new ByteArrayOutputStream + val codec = new LZ4CompressionCodec(new SparkConf()) + val compstream = codec.compressedOutputStream(buffered) + Utils.tryWithResource(new PrintWriter(compstream)) { writer => + + val applicationStart = SparkListenerApplicationStart("AppStarts", None, + 125L, "Mickey", None) + val applicationEnd = SparkListenerApplicationEnd(1000L) + + // scalastyle:off println + writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart)))) + writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd)))) + // scalastyle:on println + } + + val logFilePath = Utils.getFilePath(testDir, "events.lz4.inprogress") + val bytes = buffered.toByteArray + Utils.tryWithResource(fileSystem.create(logFilePath)) { fstream => + fstream.write(bytes, 0, buffered.size) + } + + // Read the compressed .inprogress file and verify only first event was parsed. 
+ val conf = EventLoggingListenerSuite.getLoggingConf(logFilePath) + val replayer = new ReplayListenerBus() + + val eventMonster = new EventMonster(conf) + replayer.addListener(eventMonster) + + // Verify that replay succeeds and returns the parsed events when the input may be truncated. + val logData = EventLoggingListener.openEventLog(logFilePath, fileSystem) + Utils.tryWithResource(new EarlyEOFInputStream(logData, buffered.size - 10)) { failingStream => + replayer.replay(failingStream, logFilePath.toString, true) + + assert(eventMonster.loggedEvents.size === 1) + assert(failingStream.didFail) + } + + // Verify that replay throws the EOFException when the input is not expected to be truncated. + val logData2 = EventLoggingListener.openEventLog(logFilePath, fileSystem) + Utils.tryWithResource(new EarlyEOFInputStream(logData2, buffered.size - 10)) { failingStream2 => + intercept[EOFException] { + replayer.replay(failingStream2, logFilePath.toString, false) + } + } + } + // This assumes the correctness of EventLoggingListener test("End-to-end replay") { testApplicationReplay() @@ -97,7 +153,10 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp * assumption that the event logging behavior is correct (tested in a separate suite). */ private def testApplicationReplay(codecName: Option[String] = None) { - val logDirPath = Utils.getFilePath(testDir, "test-replay") + val logDir = new File(testDir.getAbsolutePath, "test-replay") + // Create the `Path` from the URI rather than from the absolute path so that its string + // representation carries an explicit file scheme. + val logDirPath = new Path(logDir.toURI) fileSystem.mkdirs(logDirPath) val conf = EventLoggingListenerSuite.getLoggingConf(logDirPath, codecName) @@ -156,4 +215,25 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp override def start() { } } + + /* + * This is a dummy input stream that wraps another input stream but ends prematurely when + * reading at the specified position, throwing an EOFException.
+ */ + private class EarlyEOFInputStream(in: InputStream, failAtPos: Int) extends InputStream { + private val countDown = new AtomicInteger(failAtPos) + + def didFail: Boolean = countDown.get == 0 + + @throws[IOException] + override def read(): Int = { + if (countDown.get == 0) { + throw new EOFException("Stream ended prematurely") + } + countDown.decrementAndGet() + in.read() + } + + override def close(): Unit = in.close() + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 8300607ea888..75ea409e16b4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.util.concurrent.{TimeoutException, TimeUnit} import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.{Await, Future} +import scala.concurrent.Future import scala.concurrent.duration.{Duration, SECONDS} import scala.language.existentials import scala.reflect.ClassTag @@ -260,7 +260,7 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa */ def awaitJobTermination(jobFuture: Future[_], duration: Duration): Unit = { try { - Await.ready(jobFuture, duration) + ThreadUtils.awaitReady(jobFuture, duration) } catch { case te: TimeoutException if backendException.get() != null => val msg = raw""" @@ -553,10 +553,10 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor */ testScheduler("multi-stage job") { - def stageToOutputParts(stageId: Int): Int = { - stageId match { + def shuffleIdToOutputParts(shuffleId: Int): Int = { + shuffleId match { case 0 => 10 - case 2 => 20 + case 1 => 20 case _ => 30 } } @@ -577,11 +577,12 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor // b/c the stage numbering is non-deterministic, so stage number alone doesn't tell // us what to check } - (task.stageId, task.stageAttemptId, task.partitionId) match { case (stage, 0, _) if stage < 4 => + val shuffleId = + scheduler.stageIdToStage(stage).asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId backend.taskSuccess(taskDescription, - DAGSchedulerSuite.makeMapStatus("hostA", stageToOutputParts(stage))) + DAGSchedulerSuite.makeMapStatus("hostA", shuffleIdToOutputParts(shuffleId))) case (4, 0, partition) => backend.taskSuccess(taskDescription, 4321 + partition) } @@ -624,6 +625,8 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor backend.taskFailed(taskDescription, fetchFailed) case (1, _, partition) => backend.taskSuccess(taskDescription, 42 + partition) + case unmatched => + fail(s"Unexpected shuffle output $unmatched") } } withBackend(runBackend _) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index f5575ce1e157..d061c7845f4a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -19,68 +19,101 @@ package org.apache.spark.scheduler import java.util.concurrent.Semaphore -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable +import org.mockito.Mockito import org.scalatest.Matchers import 
org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config.LISTENER_BUS_EVENT_QUEUE_CAPACITY +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{ResetSystemProperties, RpcUtils} class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers with ResetSystemProperties { + import LiveListenerBus._ + /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 val jobCompletionTime = 1421191296660L + private val mockSparkContext: SparkContext = Mockito.mock(classOf[SparkContext]) + private val mockMetricsSystem: MetricsSystem = Mockito.mock(classOf[MetricsSystem]) + + private def numDroppedEvents(bus: LiveListenerBus): Long = { + bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount + } + + private def queueSize(bus: LiveListenerBus): Int = { + bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue() + .asInstanceOf[Int] + } + + private def eventProcessingTimeCount(bus: LiveListenerBus): Long = { + bus.metrics.metricRegistry.timer(s"queue.$SHARED_QUEUE.listenerProcessingTime").getCount() + } + test("don't call sc.stop in listener") { sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) val listener = new SparkContextStoppingListener(sc) - val bus = new LiveListenerBus(sc) - bus.addListener(listener) - // Starting listener bus should flush all buffered events - bus.start() - bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + sc.listenerBus.addToSharedQueue(listener) + sc.listenerBus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + sc.stop() - bus.stop() assert(listener.sparkExSeen) } test("basic creation and shutdown of LiveListenerBus") { - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) + val conf = new SparkConf() val counter = new BasicJobCounter - val bus = new LiveListenerBus(sc) - bus.addListener(counter) + val bus = new LiveListenerBus(conf) + bus.addToSharedQueue(counter) - // Listener bus hasn't started yet, so posting events should not increment counter + // Metrics are initially empty. + assert(bus.metrics.numEventsPosted.getCount === 0) + assert(numDroppedEvents(bus) === 0) + assert(queueSize(bus) === 0) + assert(eventProcessingTimeCount(bus) === 0) + + // Post five events: (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + + // Five messages should be marked as received and queued, but no messages should be posted to + // listeners yet because the listener bus hasn't been started.
+ assert(bus.metrics.numEventsPosted.getCount === 5) + assert(queueSize(bus) === 5) assert(counter.count === 0) // Starting listener bus should flush all buffered events - bus.start() + bus.start(mockSparkContext, mockMetricsSystem) + Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) + assert(queueSize(bus) === 0) + assert(eventProcessingTimeCount(bus) === 5) // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 5) + assert(eventProcessingTimeCount(bus) === 5) // Listener bus must not be started twice intercept[IllegalStateException] { - val bus = new LiveListenerBus(sc) - bus.start() - bus.start() + val bus = new LiveListenerBus(conf) + bus.start(mockSparkContext, mockMetricsSystem) + bus.start(mockSparkContext, mockMetricsSystem) } // ... or stopped before starting intercept[IllegalStateException] { - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(conf) bus.stop() } } @@ -107,12 +140,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match drained = true } } - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(new SparkConf()) val blockingListener = new BlockingListener - bus.addListener(blockingListener) - bus.start() + bus.addToSharedQueue(blockingListener) + bus.start(mockSparkContext, mockMetricsSystem) bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() @@ -138,6 +170,43 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(drained) } + test("metrics for dropped listener events") { + val bus = new LiveListenerBus(new SparkConf().set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 1)) + + val listenerStarted = new Semaphore(0) + val listenerWait = new Semaphore(0) + + bus.addToSharedQueue(new SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + listenerStarted.release() + listenerWait.acquire() + } + }) + + bus.start(mockSparkContext, mockMetricsSystem) + + // Post a message to the listener bus and wait for processing to begin: + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + listenerStarted.acquire() + assert(queueSize(bus) === 0) + assert(numDroppedEvents(bus) === 0) + + // If we post an additional message then it should remain in the queue because the listener is + // busy processing the first event: + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + assert(queueSize(bus) === 1) + assert(numDroppedEvents(bus) === 0) + + // The queue is now full, so any additional events posted to the listener will be dropped: + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + assert(queueSize(bus) === 1) + assert(numDroppedEvents(bus) === 1) + + // Allow the the remaining events to be processed so we can stop the listener bus: + listenerWait.release(2) + bus.stop() + } + test("basic creation of StageInfo") { sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo @@ -184,7 +253,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {1} val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get - stageInfo2.rddInfos.size should 
be {3} // ParallelCollectionRDD, FilteredRDD, MappedRDD + stageInfo2.rddInfos.size should be {3} stageInfo2.rddInfos.forall(_.numPartitions == 4) should be {true} stageInfo2.rddInfos.exists(_.name == "Deux") should be {true} listener.stageInfos.clear() @@ -237,7 +306,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1") val d3 = d.map { i => w(i) -> (0 to (i % 5)) }.setName("shuffle input 2") val d4 = d2.cogroup(d3, numSlices).map { case (k, (v1, v2)) => - w(k) -> (v1.size, v2.size) + (w(k), (v1.size, v2.size)) } d4.setName("A Cogroup") d4.collectAsMap() @@ -354,21 +423,19 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val badListener = new BadListener val jobCounter1 = new BasicJobCounter val jobCounter2 = new BasicJobCounter - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(new SparkConf()) // Propagate events to bad listener first - bus.addListener(badListener) - bus.addListener(jobCounter1) - bus.addListener(jobCounter2) - bus.start() + bus.addToSharedQueue(badListener) + bus.addToSharedQueue(jobCounter1) + bus.addToSharedQueue(jobCounter2) + bus.start(mockSparkContext, mockMetricsSystem) // Post events to all listeners, and wait until the queue is drained (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) // The exception should be caught, and the event should be propagated to other listeners - assert(bus.listenerThreadIsAlive) assert(jobCounter1.count === 5) assert(jobCounter2.count === 5) } @@ -388,6 +455,31 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) } + test("add and remove listeners to/from LiveListenerBus queues") { + val bus = new LiveListenerBus(new SparkConf(false)) + val counter1 = new BasicJobCounter() + val counter2 = new BasicJobCounter() + val counter3 = new BasicJobCounter() + + bus.addToSharedQueue(counter1) + bus.addToStatusQueue(counter2) + bus.addToStatusQueue(counter3) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 3) + + bus.removeListener(counter1) + assert(bus.activeQueues() === Set(APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + + bus.removeListener(counter2) + assert(bus.activeQueues() === Set(APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 1) + + bus.removeListener(counter3) + assert(bus.activeQueues().isEmpty) + assert(bus.findListenersByClass[BasicJobCounter]().isEmpty) + } + /** * Assert that the given list of numbers has an average that is greater than zero. 
*/ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index b22da565d86e..a1d9085fa085 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -54,7 +54,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val rdd = new RDD[String](sc, List()) { override def getPartitions = Array[Partition](StubPartition(0)) override def compute(split: Partition, context: TaskContext) = { - context.addTaskCompletionListener(context => TaskContextSuite.completed = true) + context.addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = + TaskContextSuite.completed = true + }) sys.error("failed") } } @@ -95,12 +98,16 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark test("all TaskCompletionListeners should be called even if some fail") { val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) - context.addTaskCompletionListener(_ => throw new Exception("blah")) + context.addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = throw new Exception("blah") + }) context.addTaskCompletionListener(listener) - context.addTaskCompletionListener(_ => throw new Exception("blah")) + context.addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = throw new Exception("blah") + }) intercept[TaskCompletionListenerException] { - context.markTaskCompleted() + context.markTaskCompleted(None) } verify(listener, times(1)).onTaskCompletion(any()) @@ -109,9 +116,15 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark test("all TaskFailureListeners should be called even if some fail") { val context = TaskContext.empty() val listener = mock(classOf[TaskFailureListener]) - context.addTaskFailureListener((_, _) => throw new Exception("exception in listener1")) + context.addTaskFailureListener(new TaskFailureListener { + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = + throw new Exception("exception in listener1") + }) context.addTaskFailureListener(listener) - context.addTaskFailureListener((_, _) => throw new Exception("exception in listener3")) + context.addTaskFailureListener(new TaskFailureListener { + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = + throw new Exception("exception in listener3") + }) val e = intercept[TaskCompletionListenerException] { context.markTaskFailed(new Exception("exception in task")) @@ -231,10 +244,13 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark test("immediately call a completion listener if the context is completed") { var invocations = 0 val context = TaskContext.empty() - context.markTaskCompleted() - context.addTaskCompletionListener(_ => invocations += 1) + context.markTaskCompleted(None) + context.addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = + invocations += 1 + }) assert(invocations == 1) - context.markTaskCompleted() + context.markTaskCompleted(None) assert(invocations == 1) } @@ -244,16 +260,54 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val error = new RuntimeException val 
context = TaskContext.empty() context.markTaskFailed(error) - context.addTaskFailureListener { (_, e) => - lastError = e - invocations += 1 - } + context.addTaskFailureListener(new TaskFailureListener { + override def onTaskFailure(context: TaskContext, e: Throwable): Unit = { + lastError = e + invocations += 1 + } + }) assert(lastError == error) assert(invocations == 1) context.markTaskFailed(error) assert(lastError == error) assert(invocations == 1) } + + test("TaskCompletionListenerException.getMessage should include previousError") { + val listenerErrorMessage = "exception in listener" + val taskErrorMessage = "exception in task" + val e = new TaskCompletionListenerException( + Seq(listenerErrorMessage), + Some(new RuntimeException(taskErrorMessage))) + assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage)) + } + + test("all TaskCompletionListeners should be called even if some fail or a task") { + val context = TaskContext.empty() + val listener = mock(classOf[TaskCompletionListener]) + context.addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = + throw new Exception("exception in listener1") + }) + context.addTaskCompletionListener(listener) + context.addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = + throw new Exception("exception in listener3") + }) + + val e = intercept[TaskCompletionListenerException] { + context.markTaskCompleted(Some(new Exception("exception in task"))) + } + + // Make sure listener 2 was called. + verify(listener, times(1)).onTaskCompletion(any()) + + // also need to check failure in TaskCompletionListener does not mask earlier exception + assert(e.getMessage.contains("exception in listener1")) + assert(e.getMessage.contains("exception in listener3")) + assert(e.getMessage.contains("exception in task")) + } + } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 3e55d399e9df..1bddba8f6c82 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -34,8 +34,8 @@ import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.storage.TaskResultBlockId import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.storage.TaskResultBlockId import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 8b9d45f734cd..6003899bb7be 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -24,11 +24,11 @@ import scala.collection.mutable.HashMap import org.mockito.Matchers.{anyInt, anyObject, anyString, eq => meq} import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when} import org.scalatest.BeforeAndAfterEach -import org.scalatest.mock.MockitoSugar +import org.scalatest.mockito.MockitoSugar import org.apache.spark._ -import org.apache.spark.internal.config import org.apache.spark.internal.Logging +import org.apache.spark.internal.config import 
org.apache.spark.util.ManualClock class FakeSchedulerBackend extends SchedulerBackend { @@ -87,7 +87,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B conf.set(config.BLACKLIST_ENABLED, true) sc = new SparkContext(conf) taskScheduler = - new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4), Some(blacklist)) { + new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4)) { override def createTaskSetManager(taskSet: TaskSet, maxFailures: Int): TaskSetManager = { val tsm = super.createTaskSetManager(taskSet, maxFailures) // we need to create a spied tsm just so we can set the TaskSetBlacklist @@ -98,6 +98,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B stageToMockTaskSetBlacklist(taskSet.stageId) = taskSetBlacklist tsmSpy } + + override private[scheduler] lazy val blacklistTrackerOpt = Some(blacklist) } setupHelper() } @@ -658,9 +660,14 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(tsm.isZombie) assert(failedTaskSet) val idx = failedTask.index - assert(failedTaskSetReason === s"Aborting TaskSet 0.0 because task $idx (partition $idx) " + - s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior can be " + - s"configured via spark.blacklist.*.") + assert(failedTaskSetReason === s""" + |Aborting $taskSet because task $idx (partition $idx) + |cannot run anywhere due to node and executor blacklist. + |Most recent failure: + |${tsm.taskSetBlacklistHelperOpt.get.getLatestFailureReason} + | + |Blacklisting behavior can be configured via spark.blacklist.*. + |""".stripMargin) } test("don't abort if there is an executor available, though it hasn't had scheduled tasks yet") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index 6b52c10b2c68..18981d5be2f9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -37,7 +37,8 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // First, mark task 0 as failed on exec1. // task 0 should be blacklisted on exec1, and nowhere else - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec1", index = 0, failureReason = "testing") for { executor <- (1 to 4).map(_.toString) index <- 0 until 10 @@ -49,17 +50,20 @@ class TaskSetBlacklistSuite extends SparkFunSuite { assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark task 1 failed on exec1 -- this pushes the executor into the blacklist - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark one task as failed on exec2 -- not enough for any further blacklisting yet. 
- taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark another task as failed on exec2 -- now we blacklist exec2, which also leads to // blacklisting the entire node. - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) @@ -86,8 +90,8 @@ class TaskSetBlacklistSuite extends SparkFunSuite { Seq("exec1", "exec2").foreach { exec => assert( execToFailures(exec).taskToFailureCountAndFailureTime === Map( - 0 -> (1, 0), - 1 -> (1, 0) + 0 -> ((1, 0)), + 1 -> ((1, 0)) ) ) } @@ -108,34 +112,41 @@ class TaskSetBlacklistSuite extends SparkFunSuite { .set(config.MAX_FAILED_EXEC_PER_NODE_STAGE, 3) val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) // Fail a task twice on hostA, exec:1 - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTask("1", 0)) assert(!taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail the same task once more on hostA, exec:2 - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail another task on hostA, exec:1. Now that executor has failures on two different tasks, // so its blacklisted - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail a third task on hostA, exec:2, so that exec is blacklisted for the whole task set - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 2) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "2", index = 2, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail a fourth & fifth task on hostA, exec:3. Now we've got three executors that are // blacklisted for the taskset, so blacklist the whole node. 
- taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 3) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 4) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 3, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 4, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("3")) assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) } @@ -147,13 +158,17 @@ class TaskSetBlacklistSuite extends SparkFunSuite { val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) - taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 9ca6b8b0fe63..5c712bd6a545 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -23,16 +23,16 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.mockito.Matchers.{any, anyInt, anyString} -import org.mockito.Mockito.{mock, never, spy, verify, when} +import org.mockito.Mockito.{mock, never, spy, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.apache.spark._ -import org.apache.spark.internal.config import org.apache.spark.internal.Logging +import org.apache.spark.internal.config import org.apache.spark.serializer.SerializerInstance import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{AccumulatorV2, ManualClock} +import org.apache.spark.util.{AccumulatorV2, ManualClock, Utils} class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -60,6 +60,10 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) exception: Option[Throwable]): Unit = { taskScheduler.taskSetsFailed += taskSet.id } + + override def speculativeTaskSubmitted(task: Task[_]): Unit = { + taskScheduler.speculativeTasks += task.partitionId + } } // Get the rack for a given host @@ -92,6 +96,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex val endedTasks = new mutable.HashMap[Long, TaskEndReason] val finishedManagers 
= new ArrayBuffer[TaskSetManager] val taskSetsFailed = new ArrayBuffer[String] + val speculativeTasks = new ArrayBuffer[Int] val executors = new mutable.HashMap[String, String] for ((execId, host) <- liveExecutors) { @@ -139,6 +144,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex } } + override def getRackForHost(value: String): Option[String] = FakeRackUtil.getRackForHost(value) } @@ -929,6 +935,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // > 0ms, so advance the clock by 1ms here. clock.advance(1) assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(3)) + // Offer resource to start the speculative attempt for the running task val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption5.isDefined) @@ -1016,6 +1024,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // > 0ms, so advance the clock by 1ms here. clock.advance(1) assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(3, 4)) // Offer resource to start the speculative attempt for the running task val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption5.isDefined) @@ -1070,11 +1079,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched.dagScheduler = mockDAGScheduler val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) - when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).then(new Answer[Unit] { - override def answer(invocationOnMock: InvocationOnMock): Unit = { - assert(manager.isZombie === true) - } - }) + when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).thenAnswer( + new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + assert(manager.isZombie) + } + }) val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) // this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon @@ -1136,7 +1146,113 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Make sure that the blacklist ignored all of the task failures above, since they aren't // the fault of the executor where the task was running. verify(blacklist, never()) - .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) + .updateBlacklistForFailedTask(anyString(), anyString(), anyInt(), anyString()) + } + + test("update application blacklist for shuffle-fetch") { + // Setup a taskset, and fail some one task for fetch failure. 
+ val conf = new SparkConf() + .set(config.BLACKLIST_ENABLED, true) + .set(config.SHUFFLE_SERVICE_ENABLED, true) + .set(config.BLACKLIST_FETCH_FAILURE_ENABLED, true) + sc = new SparkContext("local", "test", conf) + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4) + val blacklistTracker = new BlacklistTracker(sc, None) + val tsm = new TaskSetManager(sched, taskSet, 4, Some(blacklistTracker)) + + // make some offers to our taskset, to get tasks we will fail + val taskDescs = Seq( + "exec1" -> "host1", + "exec2" -> "host2" + ).flatMap { case (exec, host) => + // offer each executor twice (simulating 2 cores per executor) + (0 until 2).flatMap{ _ => tsm.resourceOffer(exec, host, TaskLocality.ANY)} + } + assert(taskDescs.size === 4) + + assert(!blacklistTracker.isExecutorBlacklisted(taskDescs(0).executorId)) + assert(!blacklistTracker.isNodeBlacklisted("host1")) + + // Fail the task with a fetch failure + tsm.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, + FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) + + assert(blacklistTracker.isNodeBlacklisted("host1")) + } + + test("update blacklist before adding pending task to avoid race condition") { + // When a task fails, the blacklist policy should be applied before the task is re-queued for + // retry; otherwise there is a race condition in which the retry can run on the same executor + // that it was meant to be blacklisted from. + val conf = new SparkConf(). + set(config.BLACKLIST_ENABLED, true) + + // Create a scheduler with two executors and a task set containing a single task. + sc = new SparkContext("local", "test", conf) + val exec = "executor1" + val host = "host1" + val exec2 = "executor2" + val host2 = "host2" + sched = new FakeTaskScheduler(sc, (exec, host), (exec2, host2)) + val taskSet = FakeTask.createTaskSet(1) + + val clock = new ManualClock + val mockListenerBus = mock(classOf[LiveListenerBus]) + val blacklistTracker = new BlacklistTracker(mockListenerBus, conf, None, clock) + val taskSetManager = new TaskSetManager(sched, taskSet, 1, Some(blacklistTracker)) + val taskSetManagerSpy = spy(taskSetManager) + + val taskDesc = taskSetManagerSpy.resourceOffer(exec, host, TaskLocality.ANY) + + // Assert that the task has been blacklisted on the executor it last ran on. + when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer( + new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + val task = invocationOnMock.getArgumentAt(0, classOf[Int]) + assert(taskSetManager.taskSetBlacklistHelperOpt.get.
+ isExecutorBlacklistedForTask(exec, task)) + } + } + ) + + // Simulate a fake exception + val e = new ExceptionFailure("a", "b", Array(), "c", None) + taskSetManagerSpy.handleFailedTask(taskDesc.get.taskId, TaskState.FAILED, e) + + verify(taskSetManagerSpy, times(1)).addPendingTask(anyInt()) + } + + test("SPARK-21563 context's added jars shouldn't change mid-TaskSet") { + sc = new SparkContext("local", "test") + val addedJarsPreTaskSet = Map[String, Long](sc.addedJars.toSeq: _*) + assert(addedJarsPreTaskSet.size === 0) + + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet1 = FakeTask.createTaskSet(3) + val manager1 = new TaskSetManager(sched, taskSet1, MAX_TASK_FAILURES, clock = new ManualClock) + + // all tasks from the first taskset have the same jars + val taskOption1 = manager1.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption1.get.addedJars === addedJarsPreTaskSet) + val taskOption2 = manager1.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption2.get.addedJars === addedJarsPreTaskSet) + + // even with a jar added mid-TaskSet + val jarPath = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar") + sc.addJar(jarPath.toString) + val addedJarsMidTaskSet = Map[String, Long](sc.addedJars.toSeq: _*) + assert(addedJarsPreTaskSet !== addedJarsMidTaskSet) + val taskOption3 = manager1.resourceOffer("exec1", "host1", NO_PREF) + // which should have the old version of the jars list + assert(taskOption3.get.addedJars === addedJarsPreTaskSet) + + // and then the jar does appear in the next TaskSet + val taskSet2 = FakeTask.createTaskSet(1) + val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES, clock = new ManualClock) + + val taskOption4 = manager2.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption4.get.addedJars === addedJarsMidTaskSet) } private def createTaskResult( diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 608052f5ed85..78f618f8a216 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -130,6 +130,7 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { val conf = createConf() val key = createKey(conf) val file = Files.createTempFile("crypto", ".test").toFile() + file.deleteOnExit() val outStream = createCryptoOutputStream(new FileOutputStream(file), conf, key) try { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala b/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala index 64be96627614..a1cf3570a7a6 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala @@ -78,10 +78,10 @@ class KryoBenchmark extends SparkFunSuite { sum } } - basicTypes("Int", Random.nextInt) - basicTypes("Long", Random.nextLong) - basicTypes("Float", Random.nextFloat) - basicTypes("Double", Random.nextDouble) + basicTypes("Int", () => Random.nextInt()) + basicTypes("Long", () => Random.nextLong()) + basicTypes("Float", () => Random.nextFloat()) + basicTypes("Double", () => Random.nextDouble()) // Benchmark Array of Primitives val arrayCount = 10000 @@ -101,10 +101,10 @@ class KryoBenchmark extends SparkFunSuite { sum } } - basicTypeArray("Int", Random.nextInt) - basicTypeArray("Long", Random.nextLong) - basicTypeArray("Float", Random.nextFloat) - 
basicTypeArray("Double", Random.nextDouble) + basicTypeArray("Int", () => Random.nextInt()) + basicTypeArray("Long", () => Random.nextLong()) + basicTypeArray("Float", () => Random.nextFloat()) + basicTypeArray("Double", () => Random.nextDouble()) // Benchmark Maps val mapsCount = 1000 diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index 21251f0b9376..cf01f79f4909 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.serializer import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.LocalSparkContext +import org.apache.spark.LocalSparkContext._ import org.apache.spark.SparkContext import org.apache.spark.SparkException @@ -32,9 +32,9 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryoserializer.buffer", "1m") conf.set("spark.kryoserializer.buffer.max", "1m") - val sc = new SparkContext("local", "test", conf) - intercept[SparkException](sc.parallelize(x).collect()) - LocalSparkContext.stop(sc) + withSpark(new SparkContext("local", "test", conf)) { sc => + intercept[SparkException](sc.parallelize(x).collect()) + } } test("kryo with resizable output buffer should succeed on large array") { @@ -42,8 +42,8 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryoserializer.buffer", "1m") conf.set("spark.kryoserializer.buffer.max", "2m") - val sc = new SparkContext("local", "test", conf) - assert(sc.parallelize(x).collect() === x) - LocalSparkContext.stop(sc) + withSpark(new SparkContext("local", "test", conf)) { sc => + assert(sc.parallelize(x).collect() === x) + } } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 7c3922e47fbb..eaec098b8d78 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -276,7 +276,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("kryo with collect for specialized tuples") { - assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11)) + assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === ((1, 11))) } test("kryo with SerializableHyperLogLog") { @@ -475,7 +475,7 @@ class KryoSerializerAutoResetDisabledSuite extends SparkFunSuite with SharedSpar val deserializationStream = serInstance.deserializeStream(new ByteArrayInputStream(worldWorld)) assert(deserializationStream.readValue[Any]() === world) deserializationStream.close() - assert(serInstance.deserialize[Any](helloHello) === (hello, hello)) + assert(serInstance.deserialize[Any](helloHello) === ((hello, hello))) } } diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala index 1bfb0c1547ec..82bd7c4ff660 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala @@ -31,7 +31,7 @@ class AllStagesResourceSuite extends SparkFunSuite { val tasks = new LinkedHashMap[Long, TaskUIData] taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => tasks(idx.toLong) = TaskUIData( - new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None) + new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false)) } val stageUiData = new StageUIData() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index 89ed031b6fcd..f0c521b00b58 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.storage +import java.util.UUID + import org.apache.spark.SparkFunSuite class BlockIdSuite extends SparkFunSuite { @@ -67,6 +69,32 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("shuffle data") { + val id = ShuffleDataBlockId(4, 5, 6) + assertSame(id, ShuffleDataBlockId(4, 5, 6)) + assertDifferent(id, ShuffleDataBlockId(6, 5, 6)) + assert(id.name === "shuffle_4_5_6.data") + assert(id.asRDDId === None) + assert(id.shuffleId === 4) + assert(id.mapId === 5) + assert(id.reduceId === 6) + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + + test("shuffle index") { + val id = ShuffleIndexBlockId(7, 8, 9) + assertSame(id, ShuffleIndexBlockId(7, 8, 9)) + assertDifferent(id, ShuffleIndexBlockId(9, 8, 9)) + assert(id.name === "shuffle_7_8_9.index") + assert(id.asRDDId === None) + assert(id.shuffleId === 7) + assert(id.mapId === 8) + assert(id.reduceId === 9) + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + test("broadcast") { val id = BroadcastBlockId(42) assertSame(id, BroadcastBlockId(42)) @@ -101,6 +129,30 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("temp local") { + val id = TempLocalBlockId(new UUID(5, 2)) + assertSame(id, TempLocalBlockId(new UUID(5, 2))) + assertDifferent(id, TempLocalBlockId(new UUID(5, 3))) + assert(id.name === "temp_local_00000000-0000-0005-0000-000000000002") + assert(id.asRDDId === None) + assert(id.isBroadcast === false) + assert(id.id.getMostSignificantBits() === 5) + assert(id.id.getLeastSignificantBits() === 2) + assert(!id.isShuffle) + } + + test("temp shuffle") { + val id = TempShuffleBlockId(new UUID(1, 2)) + assertSame(id, TempShuffleBlockId(new UUID(1, 2))) + assertDifferent(id, TempShuffleBlockId(new UUID(1, 3))) + assert(id.name === "temp_shuffle_00000000-0000-0001-0000-000000000002") + assert(id.asRDDId === None) + assert(id.isBroadcast === false) + assert(id.id.getMostSignificantBits() === 1) + assert(id.id.getLeastSignificantBits() === 2) + assert(!id.isShuffle) + } + test("test") { val id = TestBlockId("abc") assertSame(id, TestBlockId("abc")) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 1b325801e27f..917db766f7f1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -152,7 +152,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { // one should acquire the write lock. 
The second thread should block until the winner of the // write race releases its lock. val winningFuture: Future[Boolean] = - Await.ready(Future.firstCompletedOf(Seq(lock1Future, lock2Future)), 1.seconds) + ThreadUtils.awaitReady(Future.firstCompletedOf(Seq(lock1Future, lock2Future)), 1.seconds) assert(winningFuture.value.get.get) val winningTID = blockInfoManager.get("block").get.writerTask assert(winningTID === 1 || winningTID === 2) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c100803279ea..dd61dcd11bcd 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -100,7 +100,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite sc = new SparkContext("local", "test", conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(sc))), conf, true) + new LiveListenerBus(conf))), conf, true) allStores.clear() } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index a8b960489983..cfe89fde63f8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -17,30 +17,36 @@ package org.apache.spark.storage +import java.io.File import java.nio.ByteBuffer +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration._ import scala.concurrent.Future -import scala.language.implicitConversions -import scala.language.postfixOps +import scala.concurrent.duration._ +import scala.language.{implicitConversions, postfixOps} import scala.reflect.ClassTag +import org.apache.commons.lang3.RandomUtils import org.mockito.{Matchers => mc} import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest._ import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.Timeouts._ +import org.scalatest.concurrent.TimeLimits._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.config._ import org.apache.spark.memory.UnifiedMemoryManager -import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.network.{BlockDataManager, BlockTransferService, TransportContext} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} +import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} @@ -124,7 +130,7 @@ class BlockManagerSuite 
extends SparkFunSuite with Matchers with BeforeAndAfterE when(sc.conf).thenReturn(conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(sc))), conf, true) + new LiveListenerBus(conf))), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -496,8 +502,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } - test("optimize a location order of blocks") { - val localHost = Utils.localHostName() + test("optimize a location order of blocks without topology information") { + val localHost = "localhost" val otherHost = "otherHost" val bmMaster = mock(classOf[BlockManagerMaster]) val bmId1 = BlockManagerId("id1", localHost, 1) @@ -508,7 +514,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockManager = makeBlockManager(128, "exec", bmMaster) val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) - assert(locations.map(_.host).toSet === Set(localHost, localHost, otherHost)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) + } + + test("optimize a location order of blocks with topology information") { + val localHost = "localhost" + val otherHost = "otherHost" + val localRack = "localRack" + val otherRack = "otherRack" + + val bmMaster = mock(classOf[BlockManagerMaster]) + val bmId1 = BlockManagerId("id1", localHost, 1, Some(localRack)) + val bmId2 = BlockManagerId("id2", localHost, 2, Some(localRack)) + val bmId3 = BlockManagerId("id3", otherHost, 3, Some(otherRack)) + val bmId4 = BlockManagerId("id4", otherHost, 4, Some(otherRack)) + val bmId5 = BlockManagerId("id5", otherHost, 5, Some(localRack)) + when(bmMaster.getLocations(mc.any[BlockId])) + .thenReturn(Seq(bmId1, bmId2, bmId5, bmId3, bmId4)) + + val blockManager = makeBlockManager(128, "exec", bmMaster) + blockManager.blockManagerId = + BlockManagerId(SparkContext.DRIVER_IDENTIFIER, localHost, 1, Some(localRack)) + val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) + val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost, otherHost, otherHost)) + assert(locations.flatMap(_.topologyInfo) + === Seq(localRack, localRack, localRack, otherRack, otherRack)) } test("SPARK-9591: getRemoteBytes from another location when Exception throw") { @@ -891,8 +922,38 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } + test("turn off updated block statuses") { + val conf = new SparkConf() + conf.set(TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES, false) + store = makeBlockManager(12000, testConf = Some(conf)) + + store.registerTask(0) + val list = List.fill(2)(new Array[Byte](2000)) + + def getUpdatedBlocks(task: => Unit): Seq[(BlockId, BlockStatus)] = { + val context = TaskContext.empty() + try { + TaskContext.setTaskContext(context) + task + } finally { + TaskContext.unset() + } + context.taskMetrics.updatedBlockStatuses + } + + // 1 updated block (i.e. 
list1) + val updatedBlocks1 = getUpdatedBlocks { + store.putIterator( + "list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } + assert(updatedBlocks1.size === 0) + } + + test("updated block statuses") { - store = makeBlockManager(12000) + val conf = new SparkConf() + conf.set(TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES, true) + store = makeBlockManager(12000, testConf = Some(conf)) store.registerTask(0) val list = List.fill(2)(new Array[Byte](2000)) val bigList = List.fill(8)(new Array[Byte](2000)) @@ -1255,6 +1316,61 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(master.getLocations("item").isEmpty) } + test("SPARK-20640: Shuffle registration timeout and maxAttempts conf are working") { + val tryAgainMsg = "test_spark_20640_try_again" + // a server which delays response 50ms and must try twice for success. + def newShuffleServer(port: Int): (TransportServer, Int) = { + val attempts = new mutable.HashMap[String, Int]() + val handler = new NoOpRpcHandler { + override def receive( + client: TransportClient, + message: ByteBuffer, + callback: RpcResponseCallback): Unit = { + val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) + msgObj match { + case exec: RegisterExecutor => + Thread.sleep(50) + val attempt = attempts.getOrElse(exec.execId, 0) + 1 + attempts(exec.execId) = attempt + if (attempt < 2) { + callback.onFailure(new Exception(tryAgainMsg)) + return + } + callback.onSuccess(ByteBuffer.wrap(new Array[Byte](0))) + } + } + } + + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 0) + val transCtx = new TransportContext(transConf, handler, true) + (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) + } + val candidatePort = RandomUtils.nextInt(1024, 65536) + val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, + newShuffleServer, conf, "ShuffleServer") + + conf.set("spark.shuffle.service.enabled", "true") + conf.set("spark.shuffle.service.port", shufflePort.toString) + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") + var e = intercept[SparkException]{ + makeBlockManager(8000, "executor1") + }.getMessage + assert(e.contains("TimeoutException")) + + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") + e = intercept[SparkException]{ + makeBlockManager(8000, "executor2") + }.getMessage + assert(e.contains(tryAgainMsg)) + + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2") + makeBlockManager(8000, "executor3") + server.close() + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 @@ -1265,7 +1381,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockIds: Array[String], - listener: BlockFetchingListener): Unit = { + listener: BlockFetchingListener, + tempShuffleFileManager: TempShuffleFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index dfecd04c1b96..4000218e71a8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ 
b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import scala.collection.mutable +import scala.language.implicitConversions import scala.util.Random import org.scalatest.{BeforeAndAfter, Matchers} diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index bbfd6df3b699..7859b0bba2b4 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.storage import java.io.{File, FileWriter} -import scala.language.reflectiveCalls - import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index bfb3ac4c15bc..cea55012c1de 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -116,6 +116,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.revertPartialWritesAndClose() assert(firstSegment.length === file.length()) assert(writeMetrics.bytesWritten === file.length()) + assert(writeMetrics.recordsWritten == 1) } test("calling revertPartialWritesAndClose() after commit() should have no effect") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 67fc084e8a13..7258fdf5efc0 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -26,8 +26,8 @@ import io.netty.channel.FileRegion import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils} import org.apache.spark.security.CryptoStreamUtils -import org.apache.spark.util.io.ChunkedByteBuffer import org.apache.spark.util.Utils +import org.apache.spark.util.io.ChunkedByteBuffer class DiskStoreSuite extends SparkFunSuite { @@ -50,18 +50,18 @@ class DiskStoreSuite extends SparkFunSuite { val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager, securityManager) diskStoreMapped.putBytes(blockId, byteBuffer) - val mapped = diskStoreMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer + val mapped = diskStoreMapped.getBytes(blockId).toByteBuffer() assert(diskStoreMapped.remove(blockId)) val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager, securityManager) diskStoreNotMapped.putBytes(blockId, byteBuffer) - val notMapped = diskStoreNotMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer + val notMapped = diskStoreNotMapped.getBytes(blockId).toByteBuffer() // Not possible to do isInstanceOf due to visibility of HeapByteBuffer - assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")), + assert(notMapped.getClass.getName.endsWith("HeapByteBuffer"), "Expected HeapByteBuffer for un-mapped read") - assert(mapped.getChunks().forall(_.isInstanceOf[MappedByteBuffer]), + assert(mapped.isInstanceOf[MappedByteBuffer], "Expected MappedByteBuffer for mapped read") def 
arrayFromByteBuffer(in: ByteBuffer): Array[Byte] = { @@ -70,8 +70,8 @@ class DiskStoreSuite extends SparkFunSuite { array } - assert(Arrays.equals(mapped.toArray, bytes)) - assert(Arrays.equals(notMapped.toArray, bytes)) + assert(Arrays.equals(new ChunkedByteBuffer(mapped).toArray, bytes)) + assert(Arrays.equals(new ChunkedByteBuffer(notMapped).toArray, bytes)) } test("block size tracking") { @@ -92,6 +92,44 @@ class DiskStoreSuite extends SparkFunSuite { assert(diskStore.getSize(blockId) === 0L) } + test("blocks larger than 2gb") { + val conf = new SparkConf() + .set("spark.storage.memoryMapLimitForTests", "10k" ) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val arr = new Array[Byte](1024) + for { + _ <- 0 until 20 + } { + val buf = ByteBuffer.wrap(arr) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + } + + val blockData = diskStore.getBytes(blockId) + assert(blockData.size == 20 * 1024) + + val chunkedByteBuffer = blockData.toChunkedByteBuffer(ByteBuffer.allocate) + val chunks = chunkedByteBuffer.chunks + assert(chunks.size === 2) + for (chunk <- chunks) { + assert(chunk.limit === 10 * 1024) + } + + val e = intercept[IllegalArgumentException]{ + blockData.toByteBuffer() + } + + assert(e.getMessage === + s"requirement failed: can't create a byte buffer of size ${blockData.size}" + + " since it exceeds 10.0 KB.") + } + test("block data encryption") { val testDir = Utils.createTempDir() val testData = new Array[Byte](128 * 1024) diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index f7b3a2754f0e..6883eb211efd 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -37,27 +37,50 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { Utils.clearLocalRootDirs() } + private def assumeNonExistentAndNotCreatable(f: File): Unit = { + try { + assume(!f.exists() && !f.mkdirs()) + } finally { + Utils.deleteRecursively(f) + } + } + test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { // Regression test for SPARK-2974 - assert(!new File("/NONEXISTENT_PATH").exists()) + val f = new File("/NONEXISTENT_PATH") + assumeNonExistentAndNotCreatable(f) + val conf = new SparkConf(false) .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}") assert(new File(Utils.getLocalDir(conf)).exists()) + + // This directory should not be created. + assert(!f.exists()) } test("SPARK_LOCAL_DIRS override also affects driver") { - // Regression test for SPARK-2975 - assert(!new File("/NONEXISTENT_PATH").exists()) + // Regression test for SPARK-2974 + val f = new File("/NONEXISTENT_PATH") + assumeNonExistentAndNotCreatable(f) + // spark.local.dir only contains invalid directories, but that's not a problem since // SPARK_LOCAL_DIRS will override it on both the driver and workers: val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir"))) .set("spark.local.dir", "/NONEXISTENT_PATH") assert(new File(Utils.getLocalDir(conf)).exists()) + + // This directory should not be created. 
+ assert(!f.exists()) } test("Utils.getLocalDir() throws an exception if any temporary directory cannot be retrieved") { val path1 = "/NONEXISTENT_PATH_ONE" val path2 = "/NONEXISTENT_PATH_TWO" + val f1 = new File(path1) + val f2 = new File(path2) + assumeNonExistentAndNotCreatable(f1) + assumeNonExistentAndNotCreatable(f2) + assert(!new File(path1).exists()) assert(!new File(path2).exists()) val conf = new SparkConf(false).set("spark.local.dir", s"$path1,$path2") @@ -67,5 +90,9 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { // If any temporary directory could not be retrieved under the given paths above, it should // throw an exception with the message that includes the paths. assert(message.contains(s"$path1,$path2")) + + // These directories should not be created. + assert(!f1.exists()) + assert(!f2.exists()) } } diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala index 9929ea033a99..7274072e5049 100644 --- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -407,4 +407,123 @@ class MemoryStoreSuite }) assert(memoryStore.getSize(blockId) === 10000) } + + test("SPARK-22083: Release all locks in evictBlocksToFreeSpace") { + // Setup a memory store with many blocks cached, and then one request which leads to multiple + // blocks getting evicted. We'll make the eviction throw an exception, and make sure that + // all locks are released. + val ct = implicitly[ClassTag[Array[Byte]]] + val numInitialBlocks = 10 + val memStoreSize = 100 + val bytesPerSmallBlock = memStoreSize / numInitialBlocks + def testFailureOnNthDrop(numValidBlocks: Int, readLockAfterDrop: Boolean): Unit = { + val tc = TaskContext.empty() + val memManager = new StaticMemoryManager(conf, Long.MaxValue, memStoreSize, numCores = 1) + val blockInfoManager = new BlockInfoManager + blockInfoManager.registerTask(tc.taskAttemptId) + var droppedSoFar = 0 + val blockEvictionHandler = new BlockEvictionHandler { + var memoryStore: MemoryStore = _ + + override private[storage] def dropFromMemory[T: ClassTag]( + blockId: BlockId, + data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = { + if (droppedSoFar < numValidBlocks) { + droppedSoFar += 1 + memoryStore.remove(blockId) + if (readLockAfterDrop) { + // for testing purposes, we act like another thread gets the read lock on the new + // block + StorageLevel.DISK_ONLY + } else { + StorageLevel.NONE + } + } else { + throw new RuntimeException(s"Mock error dropping block $droppedSoFar") + } + } + } + val memoryStore = new MemoryStore(conf, blockInfoManager, serializerManager, memManager, + blockEvictionHandler) { + override def afterDropAction(blockId: BlockId): Unit = { + if (readLockAfterDrop) { + // pretend that we get a read lock on the block (now on disk) in another thread + TaskContext.setTaskContext(tc) + blockInfoManager.lockForReading(blockId) + TaskContext.unset() + } + } + } + + blockEvictionHandler.memoryStore = memoryStore + memManager.setMemoryStore(memoryStore) + + // Put in some small blocks to fill up the memory store + val initialBlocks = (1 to numInitialBlocks).map { id => + val blockId = BlockId(s"rdd_1_$id") + val blockInfo = new BlockInfo(StorageLevel.MEMORY_ONLY, ct, tellMaster = false) + val initialWriteLock = blockInfoManager.lockNewBlockForWriting(blockId, blockInfo) + assert(initialWriteLock) + val success = memoryStore.putBytes(blockId, 
bytesPerSmallBlock, MemoryMode.ON_HEAP, () => { + new ChunkedByteBuffer(ByteBuffer.allocate(bytesPerSmallBlock)) + }) + assert(success) + blockInfoManager.unlock(blockId, None) + } + assert(blockInfoManager.size === numInitialBlocks) + + + // Add one big block, which will require evicting everything in the memorystore. However our + // mock BlockEvictionHandler will throw an exception -- make sure all locks are cleared. + val largeBlockId = BlockId(s"rdd_2_1") + val largeBlockInfo = new BlockInfo(StorageLevel.MEMORY_ONLY, ct, tellMaster = false) + val initialWriteLock = blockInfoManager.lockNewBlockForWriting(largeBlockId, largeBlockInfo) + assert(initialWriteLock) + if (numValidBlocks < numInitialBlocks) { + val exc = intercept[RuntimeException] { + memoryStore.putBytes(largeBlockId, memStoreSize, MemoryMode.ON_HEAP, () => { + new ChunkedByteBuffer(ByteBuffer.allocate(memStoreSize)) + }) + } + assert(exc.getMessage().startsWith("Mock error dropping block"), exc) + // BlockManager.doPut takes care of releasing the lock for the newly written block -- not + // testing that here, so do it manually + blockInfoManager.removeBlock(largeBlockId) + } else { + memoryStore.putBytes(largeBlockId, memStoreSize, MemoryMode.ON_HEAP, () => { + new ChunkedByteBuffer(ByteBuffer.allocate(memStoreSize)) + }) + // BlockManager.doPut takes care of releasing the lock for the newly written block -- not + // testing that here, so do it manually + blockInfoManager.unlock(largeBlockId) + } + + val largeBlockInMemory = if (numValidBlocks == numInitialBlocks) 1 else 0 + val expBlocks = numInitialBlocks + + (if (readLockAfterDrop) 0 else -numValidBlocks) + + largeBlockInMemory + assert(blockInfoManager.size === expBlocks) + + val blocksStillInMemory = blockInfoManager.entries.filter { case (id, info) => + assert(info.writerTask === BlockInfo.NO_WRITER, id) + // in this test, all the blocks in memory have no reader, but everything dropped to disk + // had another thread read the block. We shouldn't lose the other thread's reader lock. 
+ if (memoryStore.contains(id)) { + assert(info.readerCount === 0, id) + true + } else { + assert(info.readerCount === 1, id) + false + } + } + assert(blocksStillInMemory.size === + (numInitialBlocks - numValidBlocks + largeBlockInMemory)) + } + + Seq(0, 3, numInitialBlocks).foreach { failAfterDropping => + Seq(true, false).foreach { readLockAfterDropping => + testFailureOnNthDrop(failAfterDropping, readLockAfterDropping) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala index 3050f9a25023..535105379963 100644 --- a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala @@ -145,7 +145,7 @@ class PartiallySerializedBlockSuite try { TaskContext.setTaskContext(TaskContext.empty()) val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted() + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None) Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose() Mockito.verifyNoMoreInteractions(memoryStore) } finally { diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala index 4253cc8ca4cd..cbc903f17ad7 100644 --- a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import org.mockito.Matchers import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar +import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode.ON_HEAP diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index e56e440380a5..c371cbcf8dff 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, InputStream, IOException} +import java.util.UUID import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global @@ -32,9 +33,10 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempShuffleFileManager} import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.util.Utils class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { @@ -44,7 +46,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -106,6 +109,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, + Int.MaxValue, true) // 3 local blocks fetched in initialization @@ -134,7 +139,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -153,7 +158,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -181,6 +187,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, + Int.MaxValue, true) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() @@ -192,7 +200,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() - taskContext.markTaskCompleted() + taskContext.markTaskCompleted(None) verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release() // The 3rd block should not be retained because the iterator is already in zombie state @@ -218,7 +226,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -246,6 +255,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, + Int.MaxValue, true) // Continue only after the mock calls onBlockFetchFailure @@ -281,7 +292,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), 
any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -309,6 +321,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => new LimitedInputStream(in, 100), 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, + Int.MaxValue, true) // Continue only after the mock calls onBlockFetchFailure @@ -318,7 +332,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -359,7 +374,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -387,6 +403,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => new LimitedInputStream(in, 100), 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, + Int.MaxValue, false) // Continue only after the mock calls onBlockFetchFailure @@ -401,4 +419,66 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(id3 === ShuffleBlockId(0, 2, 0)) } + test("Blocks should be shuffled to disk when size of the request is above the" + + " threshold(maxReqSizeShuffleToMem).") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + val diskBlockManager = mock(classOf[DiskBlockManager]) + val tmpDir = Utils.createTempDir() + doReturn{ + val blockId = TempLocalBlockId(UUID.randomUUID()) + (blockId, new File(tmpDir, blockId.name)) + }.when(diskBlockManager).createTempLocalBlock() + doReturn(diskBlockManager).when(blockManager).diskBlockManager + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) + val transfer = mock(classOf[BlockTransferService]) + var tempShuffleFileManager: TempShuffleFileManager = null + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + tempShuffleFileManager = invocation.getArguments()(5).asInstanceOf[TempShuffleFileManager] + Future { + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) + } + } + }) + + def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + 
// Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the + // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks + // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. + new ShuffleBlockFetcherIterator( + TaskContext.empty(), + transfer, + blockManager, + blocksByAddress, + (_, in) => in, + maxBytesInFlight = Int.MaxValue, + maxReqsInFlight = Int.MaxValue, + maxBlocksInFlightPerAddress = Int.MaxValue, + maxReqSizeShuffleToMem = 200, + detectCorrupt = true) + } + + val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + fetchShuffleBlock(blocksByAddress1) + // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch + // shuffle block to disk. + assert(tempShuffleFileManager == null) + + val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + fetchShuffleBlock(blocksByAddress2) + // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch + // shuffle block to disk. + assert(tempShuffleFileManager != null) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index bdd148875e38..267c8dc1bd75 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -320,12 +320,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B eventually(timeout(5 seconds), interval(50 milliseconds)) { goToUi(sc, "/jobs") find(cssSelector(".stage-progress-cell")).get.text should be ("2/2 (1 failed)") - // Ideally, the following test would pass, but currently we overcount completed tasks - // if task recomputations occur: - // find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") - // Instead, we guarantee that the total number of tasks is always correct, while the number - // of completed tasks may be higher: - find(cssSelector(".progress-cell .progress")).get.text should be ("3/2 (1 failed)") + find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") } val jobJson = getJson(sc.ui.get, "jobs") (jobJson \ "numTasks").extract[Int]should be (2) diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 0c3d4caeeabf..36ea3799afdf 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -200,36 +200,34 @@ class UISuite extends SparkFunSuite { } test("verify proxy rewrittenURI") { - val prefix = "/proxy/worker-id" + val prefix = "/worker-id" val target = "http://localhost:8081" - val path = "/proxy/worker-id/json" + val path = "/worker-id/json" var rewrittenURI = JettyUtils.createProxyURI(prefix, target, path, null) assert(rewrittenURI.toString() === "http://localhost:8081/json") rewrittenURI = JettyUtils.createProxyURI(prefix, target, path, "test=done") assert(rewrittenURI.toString() === "http://localhost:8081/json?test=done") - rewrittenURI = JettyUtils.createProxyURI(prefix, target, "/proxy/worker-id", null) + rewrittenURI = JettyUtils.createProxyURI(prefix, target, "/worker-id", null) assert(rewrittenURI.toString() === "http://localhost:8081") - rewrittenURI = JettyUtils.createProxyURI(prefix, target, 
"/proxy/worker-id/test%2F", null) + rewrittenURI = JettyUtils.createProxyURI(prefix, target, "/worker-id/test%2F", null) assert(rewrittenURI.toString() === "http://localhost:8081/test%2F") - rewrittenURI = JettyUtils.createProxyURI(prefix, target, "/proxy/worker-id/%F0%9F%98%84", null) + rewrittenURI = JettyUtils.createProxyURI(prefix, target, "/worker-id/%F0%9F%98%84", null) assert(rewrittenURI.toString() === "http://localhost:8081/%F0%9F%98%84") - rewrittenURI = JettyUtils.createProxyURI(prefix, target, "/proxy/worker-noid/json", null) + rewrittenURI = JettyUtils.createProxyURI(prefix, target, "/worker-noid/json", null) assert(rewrittenURI === null) } test("verify rewriting location header for reverse proxy") { val clientRequest = mock(classOf[HttpServletRequest]) var headerValue = "http://localhost:4040/jobs" - val prefix = "/proxy/worker-id" val targetUri = URI.create("http://localhost:4040") when(clientRequest.getScheme()).thenReturn("http") when(clientRequest.getHeader("host")).thenReturn("localhost:8080") - var newHeader = JettyUtils.createProxyLocationHeader( - prefix, headerValue, clientRequest, targetUri) + when(clientRequest.getPathInfo()).thenReturn("/proxy/worker-id/jobs") + var newHeader = JettyUtils.createProxyLocationHeader(headerValue, clientRequest, targetUri) assert(newHeader.toString() === "http://localhost:8080/proxy/worker-id/jobs") headerValue = "http://localhost:4041/jobs" - newHeader = JettyUtils.createProxyLocationHeader( - prefix, headerValue, clientRequest, targetUri) + newHeader = JettyUtils.createProxyLocationHeader(headerValue, clientRequest, targetUri) assert(newHeader === null) } diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index c770fd5da76f..423daacc0f5a 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -133,6 +133,45 @@ class UIUtilsSuite extends SparkFunSuite { assert(decoded2 === decodeURLParameter(decoded2)) } + test("SPARK-20393: Prevent newline characters in parameters.") { + val encoding = "Encoding:base64%0d%0a%0d%0aPGh0bWw%2bjcmlwdD48L2h0bWw%2b" + val stripEncoding = "Encoding:base64PGh0bWw%2bjcmlwdD48L2h0bWw%2b" + + assert(stripEncoding === stripXSS(encoding)) + } + + test("SPARK-20393: Prevent script from parameters running on page.") { + val scriptAlert = """>"'> - UIUtils.headerSparkPage("SQL", content, parent, Some(5000)) + val summary: NodeSeq = +
<div>
+        <ul>
+          {
+            if (listener.getRunningExecutions.nonEmpty) {
+              <li>
+                Running Queries:
+                {listener.getRunningExecutions.size}
+              </li>
+            }
+          }
+          {
+            if (listener.getCompletedExecutions.nonEmpty) {
+              <li>
+                Completed Queries:
+                {listener.getCompletedExecutions.size}
+              </li>
+            }
+          }
+          {
+            if (listener.getFailedExecutions.nonEmpty) {
+              <li>
+                Failed Queries:
+                {listener.getFailedExecutions.size}
+              </li>
+            }
+          }
+        </ul>
+      </div>
+ UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000)) } } @@ -88,13 +117,19 @@ private[ui] abstract class ExecutionTable( val duration = executionUIData.completionTime.getOrElse(currentTime) - submissionTime val runningJobs = executionUIData.runningJobs.map { jobId => - {jobId.toString}
+ + [{jobId.toString}] + } val succeededJobs = executionUIData.succeededJobs.sorted.map { jobId => - {jobId.toString}
+ + [{jobId.toString}] + } val failedJobs = executionUIData.failedJobs.sorted.map { jobId => - {jobId.toString}
+ + [{jobId.toString}] + } @@ -177,7 +212,7 @@ private[ui] class RunningExecutionTable( showFailedJobs = true) { override protected def header: Seq[String] = - baseHeader ++ Seq("Running Jobs", "Succeeded Jobs", "Failed Jobs") + baseHeader ++ Seq("Running Job IDs", "Succeeded Job IDs", "Failed Job IDs") } private[ui] class CompletedExecutionTable( @@ -195,7 +230,7 @@ private[ui] class CompletedExecutionTable( showSucceededJobs = true, showFailedJobs = false) { - override protected def header: Seq[String] = baseHeader ++ Seq("Jobs") + override protected def header: Seq[String] = baseHeader ++ Seq("Job IDs") } private[ui] class FailedExecutionTable( @@ -214,5 +249,5 @@ private[ui] class FailedExecutionTable( showFailedJobs = true) { override protected def header: Seq[String] = - baseHeader ++ Seq("Succeeded Jobs", "Failed Jobs") + baseHeader ++ Seq("Succeeded Job IDs", "Failed Job IDs") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 23fc0bd0bce1..460fc946c3e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -29,7 +29,8 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging private val listener = parent.listener override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized { - val parameterExecutionId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterExecutionId = UIUtils.stripXSS(request.getParameter("id")) require(parameterExecutionId != null && parameterExecutionId.nonEmpty, "Missing execution id parameter") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index b4a91230a001..8c27af374feb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -255,10 +255,8 @@ class SQLListener(conf: SparkConf) extends SparkListener with Logging { // heartbeat reports } case None => - // TODO Now just set attemptId to 0. Should fix here when we can get the attempt - // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, accumulatorUpdates) + finished = finishTask, accumulatorUpdates) } } case None => @@ -478,10 +476,11 @@ private[ui] class SQLStageMetrics( val stageAttemptId: Long, val taskIdToMetricUpdates: mutable.HashMap[Long, SQLTaskMetrics] = mutable.HashMap.empty) + +// TODO Should add attemptId here when we can get it from SparkListenerExecutorMetricsUpdate /** * Store all accumulatorUpdates for a Spark task. 
*/ private[ui] class SQLTaskMetrics( - val attemptId: Long, // TODO not used yet var finished: Boolean, var accumulatorUpdates: Seq[(Long, Any)]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 9d4ebcce4d10..884f945815e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -113,7 +113,7 @@ object SparkPlanGraph { } val node = new SparkPlanGraphNode( nodeIdGenerator.getAndIncrement(), planInfo.nodeName, - planInfo.simpleString, planInfo.metadata, metrics) + planInfo.simpleString, metrics) if (subgraph == null) { nodes += node } else { @@ -143,7 +143,6 @@ private[ui] class SparkPlanGraphNode( val id: Long, val name: String, val desc: String, - val metadata: Map[String, String], val metrics: Seq[SQLPlanMetric]) { def makeDotNode(metricsValue: Map[Long, String]): String = { @@ -177,7 +176,7 @@ private[ui] class SparkPlanGraphCluster( desc: String, val nodes: mutable.ArrayBuffer[SparkPlanGraphNode], metrics: Seq[SQLPlanMetric]) - extends SparkPlanGraphNode(id, name, desc, Map.empty, metrics) { + extends SparkPlanGraphNode(id, name, desc, metrics) { override def makeDotNode(metricsValue: Map[Long, String]): String = { val duration = metrics.filter(_.name.startsWith(WholeStageCodegenExec.PIPELINE_DURATION_METRIC)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index c9f5d3b3d92d..bc141b36e63b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -26,17 +26,17 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ /** * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a - * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way, - * this reduces the processing of a [[AggregateWindowFunction]] to processing the underlying + * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way + * that reduces the processing of a [[AggregateWindowFunction]] to processing the underlying * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. * * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions - * require the size of the partition processed, this value is exposed to them when the processor is - * constructed. + * require the size of the partition processed and this value is exposed to them when + * the processor is constructed. * * Processing of distinct aggregates is currently not supported. * - * The implementation is split into an object which takes care of construction, and a the actual + * The implementation is split into an object which takes care of construction, and the actual * processor class. */ private[window] object AggregateProcessor { @@ -90,7 +90,7 @@ private[window] object AggregateProcessor { updateExpressions ++= noOps evaluateExpressions += imperative case other => - sys.error(s"Unsupported Aggregate Function: $other") + sys.error(s"Unsupported aggregate function: $other") } // Create the projections. 
@@ -154,6 +154,7 @@ private[window] final class AggregateProcessor( } /** Evaluate buffer. */ - def evaluate(target: InternalRow): Unit = - evaluateProjection.target(target)(buffer) + def evaluate(target: InternalRow): Unit = { + evaluateProjection.target(target)(buffer) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 950a6794a74a..800a2ea3f399 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -25,8 +25,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{CalendarIntervalType, DateType, IntegerType, TimestampType} /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -109,61 +110,70 @@ case class WindowExec( * * This method uses Code Generation. It can only be used on the executor side. * - * @param frameType to evaluate. This can either be Row or Range based. - * @param offset with respect to the row. + * @param frame to evaluate. This can either be a Row or Range frame. + * @param bound with respect to the row. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { - frameType match { - case RangeFrame => - val (exprs, current, bound) = if (offset == 0) { - // Use the entire order expression when the offset is 0. - val exprs = orderSpec.map(_.child) - val buildProjection = () => newMutableProjection(exprs, child.output) - (orderSpec, buildProjection(), buildProjection()) - } else if (orderSpec.size == 1) { - // Use only the first order expression when the offset is non-null. - val sortExpr = orderSpec.head - val expr = sortExpr.child - // Create the projection which returns the current 'value'. - val current = newMutableProjection(expr :: Nil, child.output) - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => -offset - case Ascending => offset - } - // Create the projection which returns the current 'value' modified by adding the offset. - val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) - val bound = newMutableProjection(boundExpr :: Nil, child.output) - (sortExpr :: Nil, current, bound) - } else { - sys.error("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") + private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + (frame, bound) match { + case (RowFrame, CurrentRow) => + RowBoundOrdering(0) + + case (RowFrame, IntegerLiteral(offset)) => + RowBoundOrdering(offset) + + case (RangeFrame, CurrentRow) => + val ordering = newOrdering(orderSpec, child.output) + RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) + + case (RangeFrame, offset: Expression) if orderSpec.size == 1 => + // Use only the first order expression when the offset is non-null. 
+ val sortExpr = orderSpec.head + val expr = sortExpr.child + + // Create the projection which returns the current 'value'. + val current = newMutableProjection(expr :: Nil, child.output) + + // Flip the sign of the offset when processing the order is descending + val boundOffset = sortExpr.direction match { + case Descending => UnaryMinus(offset) + case Ascending => offset + } + + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = (expr.dataType, boundOffset.dataType) match { + case (DateType, IntegerType) => DateAdd(expr, boundOffset) + case (TimestampType, CalendarIntervalType) => + TimeAdd(expr, boundOffset, Some(conf.sessionLocalTimeZone)) + case (a, b) if a== b => Add(expr, boundOffset) } + val bound = newMutableProjection(boundExpr :: Nil, child.output) + // Construct the ordering. This is used to compare the result of current value projection // to the result of bound value projection. This is done manually because we want to use // Code Generation (if it is enabled). - val sortExprs = exprs.zipWithIndex.map { case (e, i) => - SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) - } - val ordering = newOrdering(sortExprs, Nil) + val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil + val ordering = newOrdering(boundSortExprs, Nil) RangeBoundOrdering(ordering, current, bound) - case RowFrame => RowBoundOrdering(offset) + + case (RangeFrame, _) => + sys.error("Non-Zero range offsets are not supported for windows " + + "with multiple order expressions.") } } /** - * Collection containing an entry for each window frame to process. Each entry contains a frames' - * WindowExpressions and factory function for the WindowFrameFunction. + * Collection containing an entry for each window frame to process. Each entry contains a frame's + * [[WindowExpression]]s and factory function for the WindowFrameFunction. */ private[this] lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type FrameKey = (String, FrameType, Expression, Expression) type ExpressionBuffer = mutable.Buffer[Expression] val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] // Add a function and its function to the map for a given frame. def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val key = (tpe, fr.frameType, fr.lower, fr.upper) val (es, fns) = framedFunctions.getOrElseUpdate( key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) es += e @@ -203,7 +213,7 @@ case class WindowExec( // Create the factory val factory = key match { // Offset Frame - case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + case ("OFFSET", _, IntegerLiteral(offset), _) => target: InternalRow => new OffsetWindowFunctionFrame( target, @@ -215,38 +225,38 @@ case class WindowExec( newMutableProjection(expressions, schema, subexpressionEliminationEnabled), offset) + // Entire Partition Frame. + case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) => + target: InternalRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + // Growing Frame. 
- case ("AGGREGATE", frameType, None, Some(high)) => + case ("AGGREGATE", frameType, UnboundedPreceding, upper) => target: InternalRow => { new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, high)) + createBoundOrdering(frameType, upper)) } // Shrinking Frame. - case ("AGGREGATE", frameType, Some(low), None) => + case ("AGGREGATE", frameType, lower, UnboundedFollowing) => target: InternalRow => { new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, low)) + createBoundOrdering(frameType, lower)) } // Moving Frame. - case ("AGGREGATE", frameType, Some(low), Some(high)) => + case ("AGGREGATE", frameType, lower, upper) => target: InternalRow => { new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, low), - createBoundOrdering(frameType, high)) - } - - // Entire Partition Frame. - case ("AGGREGATE", frameType, None, None) => - target: InternalRow => { - new UnboundedWindowFunctionFrame(target, processor) + createBoundOrdering(frameType, lower), + createBoundOrdering(frameType, upper)) } } @@ -282,6 +292,7 @@ case class WindowExec( // Unwrap the expressions and factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold // Start processing. @@ -312,7 +323,8 @@ case class WindowExec( val inputFields = child.output.length val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index af2b4fb92062..156002ef58fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -195,15 +195,6 @@ private[window] final class SlidingWindowFunctionFrame( override def write(index: Int, current: InternalRow): Unit = { var bufferUpdated = index == 0 - // Add all rows to the buffer for which the input row value is equal to or less than - // the output row upper bound. - while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { - buffer.add(nextRow.copy()) - nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) - inputHighIndex += 1 - bufferUpdated = true - } - // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { @@ -212,6 +203,19 @@ private[window] final class SlidingWindowFunctionFrame( bufferUpdated = true } + // Add all rows to the buffer for which the input row value is equal to or less than + // the output row upper bound. 
+ while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { + if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) { + inputLowIndex += 1 + } else { + buffer.add(nextRow.copy()) + bufferUpdated = true + } + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) + inputHighIndex += 1 + } + // Only recalculate and update when the buffer changes. if (bufferUpdated) { processor.initialize(input.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index b13fe7016092..03b654f83052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -1,26 +1,25 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.expressions import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.Column -import org.apache.spark.sql.functions +import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.types.DataType /** @@ -35,10 +34,6 @@ import org.apache.spark.sql.types.DataType * df.select( predict(df("score")) ) * }}} * - * @note The user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. - * * @since 1.3.0 */ @InterfaceStability.Stable @@ -47,12 +42,87 @@ case class UserDefinedFunction protected[sql] ( dataType: DataType, inputTypes: Option[Seq[DataType]]) { + private var _nameOption: Option[String] = None + private var _nullable: Boolean = true + private var _deterministic: Boolean = true + + /** + * Returns true when the UDF can return a nullable value. 
+ * + * @since 2.3.0 + */ + def nullable: Boolean = _nullable + + /** + * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the same + * input. + * + * @since 2.3.0 + */ + def deterministic: Boolean = _deterministic + /** * Returns an expression that invokes the UDF, using the given arguments. * * @since 1.3.0 */ def apply(exprs: Column*): Column = { - Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil))) + Column(ScalaUDF( + f, + dataType, + exprs.map(_.expr), + inputTypes.getOrElse(Nil), + udfName = _nameOption, + nullable = _nullable, + udfDeterministic = _deterministic)) + } + + private def copyAll(): UserDefinedFunction = { + val udf = copy() + udf._nameOption = _nameOption + udf._nullable = _nullable + udf._deterministic = _deterministic + udf + } + + /** + * Updates UserDefinedFunction with a given name. + * + * @since 2.3.0 + */ + def withName(name: String): UserDefinedFunction = { + val udf = copyAll() + udf._nameOption = Option(name) + udf + } + + /** + * Updates UserDefinedFunction to non-nullable. + * + * @since 2.3.0 + */ + def asNonNullable(): UserDefinedFunction = { + if (!nullable) { + this + } else { + val udf = copyAll() + udf._nullable = false + udf + } + } + + /** + * Updates UserDefinedFunction to nondeterministic. + * + * @since 2.3.0 + */ + def asNondeterministic(): UserDefinedFunction = { + if (!_deterministic) { + this + } else { + val udf = copyAll() + udf._deterministic = false + udf + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 00053485e614..1caa243f8d11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -75,7 +75,7 @@ object Window { } /** - * Value representing the last row in the partition, equivalent to "UNBOUNDED PRECEDING" in SQL. + * Value representing the first row in the partition, equivalent to "UNBOUNDED PRECEDING" in SQL. * This can be used to specify the frame boundaries: * * {{{ @@ -167,24 +167,24 @@ object Window { * current row. * * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `Window.currentRow` to specify special boundary values, rather than using integral - * values directly. + * and `Window.currentRow` to specify special boundary values, rather than using long values + * directly. * - * A range based boundary is based on the actual value of the ORDER BY + * A range-based boundary is based on the actual value of the ORDER BY * expression(s). An offset is used to alter the value of the ORDER BY expression, for * instance if the current order by expression has a value of 10 and the lower bound offset * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a * number of constraints on the ORDER BY expressions: there can be only one expression and this - * expression must have a numerical data type. An exception can be made when the offset is 0, - * because no value modification is needed, in this case multiple and non-numeric ORDER BY - * expression are allowed. + * expression must have a numerical data type. An exception can be made when the offset is + * unbounded, because no value modification is needed, in this case multiple and non-numeric + * ORDER BY expression are allowed. 
* * {{{ * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") * val byCategoryOrderedById = - * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * Window.partitionBy('category).orderBy('id).rangeBetween(Window.currentRow, 1) * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ @@ -210,6 +210,57 @@ object Window { spec.rangeBetween(start, end) } + /** + * Creates a [[WindowSpec]] with the frame boundaries defined, + * from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative to the current row. For example, "lit(0)" means + * "current row", while "lit(-1)" means one off before the current row, and "lit(5)" means the + * five off after the current row. + * + * Users should use `unboundedPreceding()`, `unboundedFollowing()`, and `currentRow()` from + * [[org.apache.spark.sql.functions]] to specify special boundary values, literals are not + * transformed to [[org.apache.spark.sql.catalyst.expressions.SpecialFrameBoundary]]s. + * + * A range-based boundary is based on the actual value of the ORDER BY + * expression(s). An offset is used to alter the value of the ORDER BY expression, for + * instance if the current order by expression has a value of 10 and the lower bound offset + * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a + * number of constraints on the ORDER BY expressions: there can be only one expression and this + * expression must have a numerical/date/timestamp data type. An exception can be made when the + * offset is unbounded, because no value modification is needed, in this case multiple and + * non-numerical/date/timestamp data type ORDER BY expression are allowed. + * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rangeBetween(currentRow(), lit(1)) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 4| + * | 1| a| 4| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * + * @param start boundary start, inclusive. The frame is unbounded if the expression is + * [[org.apache.spark.sql.catalyst.expressions.UnboundedPreceding]]. + * @param end boundary end, inclusive. The frame is unbounded if the expression is + * [[org.apache.spark.sql.catalyst.expressions.UnboundedFollowing]]. 
+ * @since 2.3.0 + */ + def rangeBetween(start: Column, end: Column): WindowSpec = { + spec.rangeBetween(start, end) + } + private[sql] def spec: WindowSpec = { new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 6279d48c94de..4c41aa3c5fb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.Column +import org.apache.spark.sql.{AnalysisException, Column} import org.apache.spark.sql.catalyst.expressions._ /** @@ -123,28 +123,45 @@ class WindowSpec private[sql]( */ // Note: when updating the doc for this method, also update Window.rowsBetween. def rowsBetween(start: Long, end: Long): WindowSpec = { - between(RowFrame, start, end) + val boundaryStart = start match { + case 0 => CurrentRow + case Long.MinValue => UnboundedPreceding + case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt) + case x => throw new AnalysisException(s"Boundary start is not a valid integer: $x") + } + + val boundaryEnd = end match { + case 0 => CurrentRow + case Long.MaxValue => UnboundedFollowing + case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt) + case x => throw new AnalysisException(s"Boundary end is not a valid integer: $x") + } + + new WindowSpec( + partitionSpec, + orderSpec, + SpecifiedWindowFrame(RowFrame, boundaryStart, boundaryEnd)) } /** * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). * - * Both `start` and `end` are relative from the current row. For example, "0" means "current row", - * while "-1" means one off before the current row, and "5" means the five off after the - * current row. + * Both `start` and `end` are relative from the current row. For example, "0" means + * "current row", while "-1" means one off before the current row, and "5" means the five off + * after the current row. * * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `Window.currentRow` to specify special boundary values, rather than using integral - * values directly. + * and `Window.currentRow` to specify special boundary values, rather than using long values + * directly. * - * A range based boundary is based on the actual value of the ORDER BY + * A range-based boundary is based on the actual value of the ORDER BY * expression(s). An offset is used to alter the value of the ORDER BY expression, for * instance if the current order by expression has a value of 10 and the lower bound offset * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a * number of constraints on the ORDER BY expressions: there can be only one expression and this - * expression must have a numerical data type. An exception can be made when the offset is 0, - * because no value modification is needed, in this case multiple and non-numeric ORDER BY - * expression are allowed. + * expression must have a numerical data type. An exception can be made when the offset is + * unbounded, because no value modification is needed, in this case multiple and non-numeric + * ORDER BY expression are allowed. 
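A short usage sketch for the reworked `rowsBetween`, assuming a local SparkSession and made-up data; the boundary constants and the eager `AnalysisException` check are the ones introduced in the hunk above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

val spark = SparkSession.builder().master("local[*]").appName("rows-between-sketch").getOrCreate()
import spark.implicits._

val df = Seq((1, "a"), (2, "a"), (1, "b"), (3, "b")).toDF("id", "category")

// Window.unboundedPreceding (Long.MinValue) and Window.currentRow (0L) map to the
// UnboundedPreceding / CurrentRow boundaries in the rewritten rowsBetween.
val runningTotal = Window
  .partitionBy($"category")
  .orderBy($"id")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn("running_sum", sum($"id").over(runningTotal)).show()

// A row offset outside the Int range is now rejected up front with an AnalysisException:
// Window.partitionBy($"category").orderBy($"id").rowsBetween(Int.MaxValue.toLong + 1, 0)
```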
* * {{{ * import org.apache.spark.sql.expressions.Window @@ -174,28 +191,75 @@ class WindowSpec private[sql]( */ // Note: when updating the doc for this method, also update Window.rangeBetween. def rangeBetween(start: Long, end: Long): WindowSpec = { - between(RangeFrame, start, end) - } - - private def between(typ: FrameType, start: Long, end: Long): WindowSpec = { val boundaryStart = start match { case 0 => CurrentRow case Long.MinValue => UnboundedPreceding - case x if x < 0 => ValuePreceding(-start.toInt) - case x if x > 0 => ValueFollowing(start.toInt) + case x => Literal(x) } val boundaryEnd = end match { case 0 => CurrentRow case Long.MaxValue => UnboundedFollowing - case x if x < 0 => ValuePreceding(-end.toInt) - case x if x > 0 => ValueFollowing(end.toInt) + case x => Literal(x) } new WindowSpec( partitionSpec, orderSpec, - SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd)) + SpecifiedWindowFrame(RangeFrame, boundaryStart, boundaryEnd)) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative to the current row. For example, "lit(0)" means + * "current row", while "lit(-1)" means one off before the current row, and "lit(5)" means the + * five off after the current row. + * + * Users should use `unboundedPreceding()`, `unboundedFollowing()`, and `currentRow()` from + * [[org.apache.spark.sql.functions]] to specify special boundary values, literals are not + * transformed to [[org.apache.spark.sql.catalyst.expressions.SpecialFrameBoundary]]s. + * + * A range-based boundary is based on the actual value of the ORDER BY + * expression(s). An offset is used to alter the value of the ORDER BY expression, for + * instance if the current order by expression has a value of 10 and the lower bound offset + * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a + * number of constraints on the ORDER BY expressions: there can be only one expression and this + * expression must have a numerical/date/timestamp data type. An exception can be made when the + * offset is unbounded, because no value modification is needed, in this case multiple and + * non-numerical/date/timestamp data type ORDER BY expression are allowed. + * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rangeBetween(currentRow(), lit(1)) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 4| + * | 1| a| 4| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * + * @param start boundary start, inclusive. The frame is unbounded if the expression is + * [[org.apache.spark.sql.catalyst.expressions.UnboundedPreceding]]. + * @param end boundary end, inclusive. The frame is unbounded if the expression is + * [[org.apache.spark.sql.catalyst.expressions.UnboundedFollowing]]. 
+ * @since 2.3.0 + */ + def rangeBetween(start: Column, end: Column): WindowSpec = { + new WindowSpec( + partitionSpec, + orderSpec, + SpecifiedWindowFrame(RangeFrame, start.expr, end.expr)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f07e04368389..6bbdfa3ad189 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -23,15 +23,16 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import scala.util.control.NonFatal -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -777,6 +778,32 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Window function: returns the special frame boundary that represents the first row in the + * window partition. + * + * @group window_funcs + * @since 2.3.0 + */ + def unboundedPreceding(): Column = Column(UnboundedPreceding) + + /** + * Window function: returns the special frame boundary that represents the last row in the + * window partition. + * + * @group window_funcs + * @since 2.3.0 + */ + def unboundedFollowing(): Column = Column(UnboundedFollowing) + + /** + * Window function: returns the special frame boundary that represents the current row in the + * window partition. + * + * @group window_funcs + * @since 2.3.0 + */ + def currentRow(): Column = Column(CurrentRow) /** * Window function: returns the cumulative distribution of values within a window partition, @@ -1019,7 +1046,8 @@ object functions { * @since 1.5.0 */ def broadcast[T](df: Dataset[T]): Dataset[T] = { - Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc) + Dataset[T](df.sparkSession, + ResolvedHint(df.logicalPlan, HintInfo(broadcast = true)))(df.exprEnc) } /** @@ -1209,7 +1237,7 @@ object functions { /** * Creates a new struct column. * If the input column is a column in a `DataFrame`, or a derived column expression - * that is named (i.e. aliased), its name would be remained as the StructField's name, + * that is named (i.e. aliased), its name would be retained as the StructField's name, * otherwise, the newly generated StructField's name would be auto generated as * `col` with a suffix `index + 1`, i.e. col1, col2, col3, ... * @@ -1265,7 +1293,7 @@ object functions { /** * Parses the expression string into the column that it represents, similar to - * DataFrame.selectExpr + * [[Dataset#selectExpr]]. 
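The following sketch (again with an assumed local SparkSession and sample data) combines the new `WindowSpec.rangeBetween(Column, Column)` with the `unboundedPreceding()`/`currentRow()` boundary functions added to `functions` above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{currentRow, sum, unboundedPreceding}

val spark = SparkSession.builder().master("local[*]").appName("range-between-sketch").getOrCreate()
import spark.implicits._

val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")).toDF("id", "category")

// The Column-based rangeBetween accepts the special boundaries returned by
// unboundedPreceding()/currentRow(); ordinary literals such as lit(1) also work.
val byCategory = Window
  .partitionBy($"category")
  .orderBy($"id")
  .rangeBetween(unboundedPreceding(), currentRow())

df.withColumn("running_sum", sum($"id").over(byCategory)).show()
```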
* {{{ * // get the number of words of each length * df.groupBy(expr("length(word)")).count() @@ -1321,7 +1349,8 @@ object functions { def asin(columnName: String): Column = asin(Column(columnName)) /** - * Computes the tangent inverse of the given value. + * Computes the tangent inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2 * * @group math_funcs * @since 1.4.0 @@ -1329,7 +1358,8 @@ object functions { def atan(e: Column): Column = withExpr { Atan(e.expr) } /** - * Computes the tangent inverse of the given column. + * Computes the tangent inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2 * * @group math_funcs * @since 1.4.0 @@ -1338,7 +1368,7 @@ object functions { /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * polar coordinates (r, theta). Units in radians. * * @group math_funcs * @since 1.4.0 @@ -1470,7 +1500,7 @@ object functions { } /** - * Computes the cosine of the given value. + * Computes the cosine of the given value. Units in radians. * * @group math_funcs * @since 1.4.0 @@ -1565,10 +1595,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = withExpr { - require(exprs.length > 1, "greatest requires at least 2 arguments.") - Greatest(exprs.map(_.expr)) - } + def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) } /** * Returns the greatest value of the list of column names, skipping null values. @@ -1672,10 +1699,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = withExpr { - require(exprs.length > 1, "least requires at least 2 arguments.") - Least(exprs.map(_.expr)) - } + def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) } /** * Returns the least value of the list of column names, skipping null values. @@ -1943,7 +1967,7 @@ object functions { def signum(columnName: String): Column = signum(Column(columnName)) /** - * Computes the sine of the given value. + * Computes the sine of the given value. Units in radians. * * @group math_funcs * @since 1.4.0 @@ -1975,7 +1999,7 @@ object functions { def sinh(columnName: String): Column = sinh(Column(columnName)) /** - * Computes the tangent of the given value. + * Computes the tangent of the given value. Units in radians. * * @group math_funcs * @since 1.4.0 @@ -2117,7 +2141,7 @@ object functions { * Calculates the hash code of given columns, and returns the result as an int column. * * @group misc_funcs - * @since 2.0 + * @since 2.0.0 */ @scala.annotation.varargs def hash(cols: Column*): Column = withExpr { @@ -2291,7 +2315,8 @@ object functions { } /** - * Left-pad the string column with + * Left-pad the string column with pad to a length of len. If the string column is longer + * than len, the return value is shortened to len characters. * * @group string_funcs * @since 1.5.0 @@ -2308,6 +2333,15 @@ object functions { */ def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } + /** + * Trim the specified character string from left end for the specified string column. + * @group string_funcs + * @since 2.3.0 + */ + def ltrim(e: Column, trimString: String): Column = withExpr { + StringTrimLeft(e.expr, Literal(trimString)) + } + /** * Extract a specific group matched by a Java regex, from the specified string column. * If the regex did not match, or the specified group did not match, an empty string is returned. 
@@ -2349,7 +2383,8 @@ object functions { def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) } /** - * Right-padded with pad to a length of len. + * Right-pad the string column with pad to a length of len. If the string column is longer + * than len, the return value is shortened to len characters. * * @group string_funcs * @since 1.5.0 @@ -2385,7 +2420,16 @@ object functions { def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } /** - * * Return the soundex code for the specified expression. + * Trim the specified character string from right end for the specified string column. + * @group string_funcs + * @since 2.3.0 + */ + def rtrim(e: Column, trimString: String): Column = withExpr { + StringTrimRight(e.expr, Literal(trimString)) + } + + /** + * Returns the soundex code for the specified expression. * * @group string_funcs * @since 1.5.0 @@ -2409,6 +2453,8 @@ object functions { * returns the slice of byte array that starts at `pos` in byte and is of length `len` * when str is Binary type * + * @note The position is not zero based, but 1 based index. + * * @group string_funcs * @since 1.5.0 */ @@ -2449,6 +2495,15 @@ object functions { */ def trim(e: Column): Column = withExpr { StringTrim(e.expr) } + /** + * Trim the specified character from both ends for the specified string column. + * @group string_funcs + * @since 2.3.0 + */ + def trim(e: Column, trimString: String): Column = withExpr { + StringTrim(e.expr, Literal(trimString)) + } + /** * Converts a string column to upper case. * @@ -2491,10 +2546,10 @@ object functions { * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. * - * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All - * pattern letters of `java.text.SimpleDateFormat` can be used. + * A pattern `dd.MM.yyyy` would return a string like `18.03.1993`. + * All pattern letters of `java.text.SimpleDateFormat` can be used. * - * @note Use when ever possible specialized functions like [[year]]. These benefit from a + * @note Use specialized functions like [[year]] whenever possible as they benefit from a * specialized implementation. * * @group datetime_funcs @@ -2647,7 +2702,11 @@ object functions { } /** - * Gets current Unix timestamp in seconds. + * Returns the current Unix timestamp (in seconds). + * + * @note All calls of `unix_timestamp` within the same query return the same value + * (i.e. the current timestamp is calculated at the start of query evaluation). + * * @group datetime_funcs * @since 1.5.0 */ @@ -2657,7 +2716,9 @@ object functions { /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), - * using the default timezone and the default locale, return null if fail. + * using the default timezone and the default locale. + * Returns `null` if fails. + * * @group datetime_funcs * @since 1.5.0 */ @@ -2666,22 +2727,23 @@ object functions { } /** - * Convert time string with given pattern - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * to Unix time stamp (in seconds), return null if fail. + * Converts time string with given pattern to Unix timestamp (in seconds). + * Returns `null` if fails. 
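A quick sketch of the new two-argument trim variants added in this hunk (`ltrim`, `rtrim`, and `trim` with an explicit trim string); the input string and session setup are assumptions for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{ltrim, rtrim, trim}

val spark = SparkSession.builder().master("local[*]").appName("trim-sketch").getOrCreate()
import spark.implicits._

val df = Seq("xxhello worldxx").toDF("s")

// The two-argument variants strip the given trim string instead of whitespace.
df.select(
  ltrim($"s", "x").as("ltrimmed"),   // "hello worldxx"
  rtrim($"s", "x").as("rtrimmed"),   // "xxhello world"
  trim($"s", "x").as("trimmed")      // "hello world"
).show(truncate = false)
```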
+ * + * @see + * Customizing Formats * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column, p: String): Column = withExpr {UnixTimestamp(s.expr, Literal(p)) } + def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) } /** - * Convert time string to a Unix timestamp (in seconds). - * Uses the pattern "yyyy-MM-dd HH:mm:ss" and will return null on failure. + * Convert time string to a Unix timestamp (in seconds) by casting rules to `TimestampType`. * @group datetime_funcs * @since 2.2.0 */ def to_timestamp(s: Column): Column = withExpr { - new ParseToTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + new ParseToTimestamp(s.expr) } /** @@ -2696,15 +2758,15 @@ object functions { } /** - * Converts the column into DateType. + * Converts the column into `DateType` by casting rules to `DateType`. * * @group datetime_funcs * @since 1.5.0 */ - def to_date(e: Column): Column = withExpr { ToDate(e.expr) } + def to_date(e: Column): Column = withExpr { new ParseToDate(e.expr) } /** - * Converts the column into a DateType with a specified format + * Converts the column into a `DateType` with a specified format * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) * return null if fail. * @@ -2729,8 +2791,9 @@ object functions { } /** - * Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp - * that corresponds to the same time of day in the given timezone. + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield + * '2017-07-14 03:40:00.0'. * @group datetime_funcs * @since 1.5.0 */ @@ -2739,8 +2802,9 @@ object functions { } /** - * Given a timestamp, which corresponds to a certain time of day in the given timezone, returns - * another timestamp that corresponds to the same time of day in UTC. + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time + * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield + * '2017-07-14 01:40:00.0'. * @group datetime_funcs * @since 1.5.0 */ @@ -2793,8 +2857,6 @@ object functions { * @group datetime_funcs * @since 2.0.0 */ - @Experimental - @InterfaceStability.Evolving def window( timeColumn: Column, windowDuration: String, @@ -2847,8 +2909,6 @@ object functions { * @group datetime_funcs * @since 2.0.0 */ - @Experimental - @InterfaceStability.Evolving def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = { window(timeColumn, windowDuration, slideDuration, "0 second") } @@ -2886,8 +2946,6 @@ object functions { * @group datetime_funcs * @since 2.0.0 */ - @Experimental - @InterfaceStability.Evolving def window(timeColumn: Column, windowDuration: String): Column = { window(timeColumn, windowDuration, windowDuration, "0 second") } @@ -3052,8 +3110,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s - * with the specified schema. Returns `null`, in the case of an unparseable string. + * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` + * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable + * string. * * @param e a string column containing JSON data. 
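A small sketch of the datetime behavior described above, assuming a local SparkSession: single-argument `to_timestamp`/`to_date` now follow the cast rules, while pattern-based `unix_timestamp` is unchanged.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{to_date, to_timestamp, unix_timestamp}

val spark = SparkSession.builder().master("local[*]").appName("datetime-sketch").getOrCreate()
import spark.implicits._

val df = Seq("2017-07-14 02:40:00").toDF("s")

df.select(
  // Single-argument to_timestamp/to_date follow the cast rules to TimestampType/DateType
  // instead of a hard-coded "yyyy-MM-dd HH:mm:ss" pattern.
  to_timestamp($"s").as("ts"),
  to_date($"s").as("d"),
  // unix_timestamp with an explicit pattern still parses with SimpleDateFormat.
  unix_timestamp($"s", "yyyy-MM-dd HH:mm:ss").as("epoch_seconds")
).show(truncate = false)
```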
* @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, @@ -3064,6 +3123,22 @@ object functions { * @since 2.1.0 */ def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = { + from_json(e, schema, options.asScala.toMap) + } + + /** + * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` + * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable + * string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string as a json string, it could be a + * JSON format string or a DDL-formatted string. + * + * @group collection_funcs + * @since 2.3.0 + */ + def from_json(e: Column, schema: String, options: Map[String, String]): Column = { val dataType = try { DataType.fromJson(schema) } catch { @@ -3073,9 +3148,9 @@ object functions { } /** - * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s - * into a JSON string with the specified schema. Throws an exception, in the case of an - * unsupported type. + * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, + * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. + * Throws an exception, in the case of an unsupported type. * * @param e a column containing a struct or array of the structs. * @param options options to control how the struct column is converted into a json string. @@ -3089,9 +3164,9 @@ object functions { } /** - * (Java-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s - * into a JSON string with the specified schema. Throws an exception, in the case of an - * unsupported type. + * (Java-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, + * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. + * Throws an exception, in the case of an unsupported type. * * @param e a column containing a struct or array of the structs. * @param options options to control how the struct column is converted into a json string. @@ -3104,8 +3179,9 @@ object functions { to_json(e, options.asScala.toMap) /** - * Converts a column containing a `StructType` or `ArrayType` of `StructType`s into a JSON string - * with the specified schema. Throws an exception, in the case of an unsupported type. + * Converts a column containing a `StructType`, `ArrayType` of `StructType`s, + * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. + * Throws an exception, in the case of an unsupported type. * * @param e a column containing a struct or array of the structs. * @@ -3141,6 +3217,20 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns an unordered array containing the keys of the map. + * @group collection_funcs + * @since 2.3.0 + */ + def map_keys(e: Column): Column = withExpr { MapKeys(e.expr) } + + /** + * Returns an unordered array containing the values of the map. 
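A sketch of the new collection and JSON helpers (`map_keys`, `map_values` just below, and the string-schema `from_json` overload), assuming the DDL-formatted schema fallback described in the doc and a local SparkSession with made-up JSON data.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{from_json, map_keys, map_values}

val spark = SparkSession.builder().master("local[*]").appName("json-map-sketch").getOrCreate()
import spark.implicits._

val df = Seq("""{"device": "a", "readings": {"temp": 21, "hum": 40}}""").toDF("json")

// The schema is passed as a DDL-formatted string, one of the two formats the
// string-schema overload documents.
val schema = "device STRING, readings MAP<STRING, INT>"
val parsed = df.select(from_json($"json", schema, Map.empty[String, String]).as("rec"))

parsed.select(
  $"rec.device",
  map_keys($"rec.readings").as("metrics"),    // unordered, e.g. [temp, hum]
  map_values($"rec.readings").as("values")    // unordered, e.g. [21, 40]
).show(truncate = false)
```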
+ * @group collection_funcs + * @since 2.3.0 + */ + def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// @@ -3154,157 +3244,207 @@ object functions { val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" /** - * Defines a user-defined function of ${x} arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of ${x} arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try($inputTypes).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() }""") } */ + /** - * Defines a user-defined function of 0 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 0 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } /** - * Defines a user-defined function of 1 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 1 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } /** - * Defines a user-defined function of 2 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 2 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. 
To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } /** - * Defines a user-defined function of 3 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 3 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } /** - * Defines a user-defined function of 4 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 4 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } /** - * Defines a user-defined function of 5 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 5 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
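A brief sketch of the `UserDefinedFunction` surface these docs refer to: nullability is now inferred from the return type, and `withName`/`asNondeterministic` adjust the produced UDF. Names and data below are illustrative assumptions, not part of the patch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark = SparkSession.builder().master("local[*]").appName("udf-sketch").getOrCreate()
import spark.implicits._

// Int is a primitive return type, so the inferred schema is non-nullable and the
// factory applies asNonNullable() automatically.
val strLen = udf((s: String) => s.length).withName("strLen")

// Marking the UDF nondeterministic tells the optimizer not to assume that repeated
// invocations with the same input produce the same result.
val noise = udf(() => scala.util.Random.nextInt(10)).asNondeterministic()

val df = Seq("spark", "sql").toDF("word")
df.select($"word", strLen($"word").as("len"), noise().as("noise")).show()
```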
* * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } /** - * Defines a user-defined function of 6 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 6 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } /** - * Defines a user-defined function of 7 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 7 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } /** - * Defines a user-defined function of 8 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. 
+ * Defines a deterministic user-defined function of 8 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } /** - * Defines a user-defined function of 9 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 9 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } /** - * Defines a user-defined function of 10 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 10 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { + val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).toOption - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullable() } // scalastyle:on parameter.number // scalastyle:on line.size.limit /** - * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must - * specify the output data type, and there is no automatic input type coercion. + * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, + * the caller must specify the output data type, and there is no automatic input type coercion. + * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2a801d87b12e..4e756084bbdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -57,7 +57,7 @@ abstract class BaseSessionStateBuilder( type NewBuilder = (SparkSession, Option[SessionState]) => BaseSessionStateBuilder /** - * Function that produces a new instance of the SessionStateBuilder. This is used by the + * Function that produces a new instance of the `BaseSessionStateBuilder`. This is used by the * [[SessionState]]'s clone functionality. Make sure to override this when implementing your own * [[SessionStateBuilder]]. */ @@ -168,6 +168,7 @@ abstract class BaseSessionStateBuilder( override val extendedCheckRules: Seq[LogicalPlan => Unit] = PreWriteCheck +: + PreReadCheck +: HiveOnlyCheck +: customCheckRules } @@ -208,7 +209,7 @@ abstract class BaseSessionStateBuilder( * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields. 
*/ protected def optimizer: Optimizer = { - new SparkOptimizer(catalog, conf, experimentalMethods) { + new SparkOptimizer(catalog, experimentalMethods) { override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules } @@ -286,14 +287,14 @@ abstract class BaseSessionStateBuilder( experimentalMethods, functionRegistry, udfRegistration, - catalog, + () => catalog, sqlParser, - analyzer, - optimizer, + () => analyzer, + () => optimizer, planner, streamingQueryManager, listenerManager, - resourceLoader, + () => resourceLoader, createQueryExecution, createClone) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 0b8e53868c99..142b005850a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel /** @@ -419,6 +420,17 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { sparkSession.sharedState.cacheManager.cacheQuery(sparkSession.table(tableName), Some(tableName)) } + /** + * Caches the specified table or view with the given storage level. + * + * @group cachemgmt + * @since 2.3.0 + */ + override def cacheTable(tableName: String, storageLevel: StorageLevel): Unit = { + sparkSession.sharedState.cacheManager.cacheQuery( + sparkSession.table(tableName), Some(tableName), storageLevel) + } + /** * Removes the specified table or view from the in-memory cache. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 1b341a12fc60..accbea41b960 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -42,14 +42,17 @@ import org.apache.spark.sql.util.{ExecutionListenerManager, QueryExecutionListen * @param experimentalMethods Interface to add custom planning strategies and optimizers. * @param functionRegistry Internal catalog for managing functions registered by the user. * @param udfRegistration Interface exposed to the user for registering user-defined functions. - * @param catalog Internal catalog for managing table and database states. + * @param catalogBuilder a function to create an internal catalog for managing table and database + * states. * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. - * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. - * @param optimizer Logical query plan optimizer. + * @param analyzerBuilder A function to create the logical query plan analyzer for resolving + * unresolved attributes and relations. + * @param optimizerBuilder a function to create the logical query plan optimizer. * @param planner Planner that converts optimized logical plans to physical plans. * @param streamingQueryManager Interface to start and stop streaming queries. * @param listenerManager Interface to register custom [[QueryExecutionListener]]s. 
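A usage sketch for the `cacheTable` overload added to `CatalogImpl` above; it assumes the matching method is exposed on the public `Catalog` interface (the `override` suggests it is) and uses a throwaway temp view.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

val spark = SparkSession.builder().master("local[*]").appName("cache-table-sketch").getOrCreate()
import spark.implicits._

Seq((1, "a"), (2, "b")).toDF("id", "category").createOrReplaceTempView("events")

// Cache the view at an explicit storage level rather than the default.
spark.catalog.cacheTable("events", StorageLevel.DISK_ONLY)

spark.sql("SELECT count(*) FROM events").show()
spark.catalog.uncacheTable("events")
```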
- * @param resourceLoader Session shared resource loader to load JARs, files, etc. + * @param resourceLoaderBuilder a function to create a session shared resource loader to load JARs, + * files, etc. * @param createQueryExecution Function used to create QueryExecution objects. * @param createClone Function used to create clones of the session state. */ @@ -59,17 +62,26 @@ private[sql] class SessionState( val experimentalMethods: ExperimentalMethods, val functionRegistry: FunctionRegistry, val udfRegistration: UDFRegistration, - val catalog: SessionCatalog, + catalogBuilder: () => SessionCatalog, val sqlParser: ParserInterface, - val analyzer: Analyzer, - val optimizer: Optimizer, + analyzerBuilder: () => Analyzer, + optimizerBuilder: () => Optimizer, val planner: SparkPlanner, val streamingQueryManager: StreamingQueryManager, val listenerManager: ExecutionListenerManager, - val resourceLoader: SessionResourceLoader, + resourceLoaderBuilder: () => SessionResourceLoader, createQueryExecution: LogicalPlan => QueryExecution, createClone: (SparkSession, SessionState) => SessionState) { + // The following fields are lazy to avoid creating the Hive client when creating SessionState. + lazy val catalog: SessionCatalog = catalogBuilder() + + lazy val analyzer: Analyzer = analyzerBuilder() + + lazy val optimizer: Optimizer = optimizerBuilder() + + lazy val resourceLoader: SessionResourceLoader = resourceLoaderBuilder() + def newHadoopConf(): Configuration = SessionState.newHadoopConf( sharedState.sparkContext.hadoopConfiguration, conf) @@ -109,7 +121,7 @@ private[sql] object SessionState { } /** - * Concrete implementation of a [[SessionStateBuilder]]. + * Concrete implementation of a [[BaseSessionStateBuilder]]. */ @Experimental @InterfaceStability.Unstable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index a93b70114607..ad9db308b262 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.FsUrlStreamHandlerFactory import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.CacheManager @@ -90,38 +91,38 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = - SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( + lazy val externalCatalog: ExternalCatalog = { + val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, sparkContext.hadoopConfiguration) - // Create the default database if it doesn't exist. 
- { val defaultDbDefinition = CatalogDatabase( SessionCatalog.DEFAULT_DATABASE, "default database", CatalogUtils.stringToURI(warehousePath), Map()) - // Initialize default database if it doesn't exist + // Create default database if it doesn't exist if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) { // There may be another Spark application creating default database at the same time, here we // set `ignoreIfExists = true` to avoid `DatabaseAlreadyExists` exception. externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) } - } - // Make sure we propagate external catalog events to the spark listener bus - externalCatalog.addListener(new ExternalCatalogEventListener { - override def onEvent(event: ExternalCatalogEvent): Unit = { - sparkContext.listenerBus.post(event) - } - }) + // Make sure we propagate external catalog events to the spark listener bus + externalCatalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + sparkContext.listenerBus.post(event) + } + }) + + externalCatalog + } /** * A manager for global temporary views. */ - val globalTempViewManager: GlobalTempViewManager = { + lazy val globalTempViewManager: GlobalTempViewManager = { // System preserved database should not exists in metastore. However it's hard to guarantee it // for every session, because case-sensitivity differs. Here we always lowercase it to make our // life easier. @@ -148,7 +149,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { if (SparkSession.sqlListener.get() == null) { val listener = new SQLListener(sc.conf) if (SparkSession.sqlListener.compareAndSet(null, listener)) { - sc.addSparkListener(listener) + sc.listenerBus.addToStatusQueue(listener) sc.ui.foreach(new SQLTab(listener, _)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 467d8d62d1b7..1419d69f983a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -41,4 +41,14 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect override def getJDBCType(dt: DataType): Option[JdbcType] = { dialects.flatMap(_.getJDBCType(dt)).headOption } + + override def isCascadingTruncateTable(): Option[Boolean] = { + // If any dialect claims cascading truncate, this dialect is also cascading truncate. + // Otherwise, if any dialect has unknown cascading truncate, this dialect is also unknown. 
+ dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) match { + case Some(true) => Some(true) + case _ if dialects.exists(_.isCascadingTruncateTable().isEmpty) => None + case _ => Some(false) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 190463df0d92..d160ad82888a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -17,15 +17,34 @@ package org.apache.spark.sql.jdbc -import org.apache.spark.sql.types.{BooleanType, DataType, StringType} +import java.sql.Types + +import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + override def getCatalystType( + sqlType: Int, + typeName: String, + size: Int, + md: MetadataBuilder): Option[DataType] = sqlType match { + case Types.REAL => Option(FloatType) + case Types.OTHER => + typeName match { + case "DECFLOAT" => Option(DecimalType(38, 18)) + case "XML" => Option(StringType) + case t if (t.startsWith("TIMESTAMP")) => Option(TimestampType) // TIMESTAMP WITH TIMEZONE + case _ => None + } + case _ => None + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case ShortType | ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index e328b86437d6..7c38ed68c041 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.jdbc -import java.sql.Connection +import java.sql.{Connection, Date, Timestamp} + +import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.sql.types._ @@ -123,6 +125,29 @@ abstract class JdbcDialect extends Serializable { def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { } + /** + * Escape special characters in SQL string literals. + * @param value The string to be escaped. + * @return Escaped string. + */ + @Since("2.3.0") + protected[jdbc] def escapeSql(value: String): String = + if (value == null) null else StringUtils.replace(value, "'", "''") + + /** + * Converts value to SQL expression. + * @param value The value to be converted. + * @return Converted value. + */ + @Since("2.3.0") + def compileValue(value: Any): Any = value match { + case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "'" + timestampValue + "'" + case dateValue: Date => "'" + dateValue + "'" + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") + case _ => value + } + /** * Return Some[true] iff `TRUNCATE TABLE` causes cascading default. * Some[true] : TRUNCATE TABLE causes cascading. @@ -174,6 +199,7 @@ object JdbcDialects { registerDialect(MsSqlServerDialect) registerDialect(DerbyDialect) registerDialect(OracleDialect) + registerDialect(TeradataDialect) /** * Fetch the JdbcDialect class corresponding to a given database url. 
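Because `compileValue` is now overridable on `JdbcDialect`, a dialect can control how pushed-down literals are rendered, and `JdbcDialects.registerDialect` makes the dialect visible to the JDBC data source. A hedged sketch of a third-party dialect; the product name and the `jdbc:acme` URL prefix are made up for illustration:

```scala
import java.sql.Date

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Hypothetical dialect for a fictional "acme" database.
object AcmeDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:acme")

  // Render DATE literals the way the fictional engine expects; defer to the default otherwise.
  override def compileValue(value: Any): Any = value match {
    case d: Date => s"DATE '$d'"
    case other   => super.compileValue(other)
  }
}

object RegisterAcmeDialect {
  def main(args: Array[String]): Unit = {
    // Once registered, any spark.read.jdbc(...) call whose URL matches canHandle uses this dialect.
    JdbcDialects.registerDialect(AcmeDialect)
  }
}
```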
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index f541996b651e..3b44c1de93a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{Date, Timestamp, Types} import org.apache.spark.sql.types._ @@ -43,10 +43,6 @@ private case object OracleDialect extends JdbcDialect { // Not sure if there is a more robust way to identify the field as a float (or other // numeric types that do not specify a scale. case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - case 1 => Option(BooleanType) - case 3 | 5 | 10 => Option(IntegerType) - case 19 if scale == 0L => Option(LongType) - case 19 if scale == 4L => Option(FloatType) case _ => None } } else { @@ -68,5 +64,18 @@ private case object OracleDialect extends JdbcDialect { case _ => None } + override def compileValue(value: Any): Any = value match { + // The JDBC drivers support date literals in SQL statements written in the + // format: {d 'yyyy-mm-dd'} and timestamp literals in SQL statements written + // in the format: {ts 'yyyy-mm-dd hh:mm:ss.f...'}. For details, see + // 'Oracle Database JDBC Developer’s Guide and Reference, 11g Release 1 (11.1)' + // Appendix A Reference Information. + case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "{ts '" + timestampValue + "'}" + case dateValue: Date => "{d '" + dateValue + "'}" + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") + case _ => value + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala new file mode 100644 index 000000000000..5749b791fca2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private case object TeradataDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = { url.startsWith("jdbc:teradata") } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) + case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index ff8b15b3ff3f..6057a795c8bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -91,7 +91,7 @@ trait RelationProvider { * * The difference between a [[RelationProvider]] and a [[SchemaRelationProvider]] is that * users need to provide a schema when using a [[SchemaRelationProvider]]. - * A relation provider can inherits both [[RelationProvider]] and [[SchemaRelationProvider]] + * A relation provider can inherit both [[RelationProvider]] and [[SchemaRelationProvider]] * if it can support both schema inference and user-specified schemas. * * @since 1.3.0 @@ -163,16 +163,13 @@ trait StreamSinkProvider { @InterfaceStability.Stable trait CreatableRelationProvider { /** - * Save the DataFrame to the destination and return a relation with the given parameters based on - * the contents of the given DataFrame. The mode specifies the expected behavior of createRelation - * when data already exists. - * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. - * Append mode means that when saving a DataFrame to a data source, if data already exists, - * contents of the DataFrame are expected to be appended to existing data. - * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, - * existing data is expected to be overwritten by the contents of the DataFrame. - * ErrorIfExists mode means that when saving a DataFrame to a data source, - * if data already exists, an exception is expected to be thrown. + * Saves a DataFrame to a destination (using data source-specific parameters) + * + * @param sqlContext SQLContext + * @param mode specifies what happens when the destination already exists + * @param parameters data source-specific parameters + * @param data DataFrame to save (i.e. 
the rows after executing the query) + * @return Relation with a known schema * * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 746b2a94f102..a42e28053a96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.command.DDLUtils @@ -35,7 +35,6 @@ import org.apache.spark.sql.types.StructType * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { /** @@ -60,6 +59,18 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo this } + /** + * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can + * infer the input schema automatically from data. By specifying the schema here, the underlying + * data source can skip the schema inference step, and thus speed up data loading. + * + * @since 2.3.0 + */ + def schema(schemaString: String): DataStreamReader = { + this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString)) + this + } + /** * Adds an input option for the underlying data source. * @@ -164,7 +175,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads a JSON file stream and returns the results as a `DataFrame`. * * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `wholeFile` option to true. + * default. For JSON (one record per file), set the `multiLine` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -184,6 +195,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * (e.g. 00012) *
• `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting the quoting of all characters using the backslash quoting mechanism
+ • `allowUnquotedControlChars` (default `false`): allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters)
• `mode` (default `PERMISSIVE`): sets a mode for dealing with corrupt records during parsing.
@@ -206,7 +220,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
• `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that indicates a timestamp format. Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to timestamp type.
- • `wholeFile` (default `false`): parse one record, which may span multiple lines, per file
+ • `multiLine` (default `false`): parse one record, which may span multiple lines, per file
@@ -277,7 +291,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
• `columnNameOfCorruptRecord` (default is the value specified in `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field that holds the malformed string created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
- • `wholeFile` (default `false`): parse one record, which may span multiple lines.
+ • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0d2611f9bbcc..14e7df672cc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils @@ -29,13 +29,11 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} /** - * :: Experimental :: * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, * key-value stores, etc). Use `Dataset.writeStream` to access this. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index c659ac7fcf3d..04a956b70b02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -212,7 +212,7 @@ trait GroupState[S] extends LogicalGroupState[S] { @throws[IllegalArgumentException]("when updating with null") def update(newState: S): Unit - /** Remove this state. Note that this resets any timeout configuration as well. */ + /** Remove this state. */ def remove(): Unit /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index 9ba1fc01cbd3..a033575d3d38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -23,11 +23,10 @@ import scala.concurrent.duration.Duration import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.unsafe.types.CalendarInterval /** - * :: Experimental :: * A trigger that runs a query periodically based on the processing time. If `interval` is 0, * the query will run as fast as possible. * @@ -49,7 +48,6 @@ import org.apache.spark.unsafe.types.CalendarInterval * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { @@ -57,12 +55,10 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { } /** - * :: Experimental :: * Used to create [[ProcessingTime]] triggers for [[StreamingQuery]]s. 
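The `DataStreamReader` changes above add a `schema(schemaString)` overload that takes a DDL-formatted string and rename the `wholeFile` option to `multiLine`. A small sketch of a streaming JSON read that uses both; the input path is hypothetical:

```scala
import org.apache.spark.sql.SparkSession

object StreamingJsonSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("stream-schema-sketch").master("local[*]").getOrCreate()

    // A DDL string avoids building a StructType by hand and lets the source skip schema inference.
    val events = spark.readStream
      .schema("id INT, name STRING, ts TIMESTAMP")
      .option("multiLine", "true")          // one JSON record, possibly spanning lines, per file
      .json("/tmp/streaming-json-input")    // hypothetical input directory

    val query = events.writeStream
      .format("console")
      .start()

    query.awaitTermination()
  }
}
```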
* * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") object ProcessingTime { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 12a1bb1db577..f2dfbe42260d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -19,16 +19,14 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.SparkSession /** - * :: Experimental :: * A handle to a query that is executing continuously in the background as new data arrives. * All these methods are thread-safe. * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving trait StreamingQuery { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 234a1166a195..03aeb14de502 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * Exception that stopped a [[StreamingQuery]]. Use `cause` get the actual exception * that caused the failure. * @param message Message of this exception @@ -29,7 +28,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} * @param endOffset Ending offset in json of the range of data in exception occurred * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryException private[sql]( private val queryDebugString: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index c376913516ef..6aa82b89ede8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -19,17 +19,15 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.scheduler.SparkListenerEvent /** - * :: Experimental :: * Interface for listening to events related to [[StreamingQuery StreamingQueries]]. * @note The methods are not thread-safe as they may be called from different threads. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving abstract class StreamingQueryListener { @@ -66,32 +64,26 @@ abstract class StreamingQueryListener { /** - * :: Experimental :: * Companion object of [[StreamingQueryListener]] that defines the listener events. 
* @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving object StreamingQueryListener { /** - * :: Experimental :: * Base type of [[StreamingQueryListener]] events * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving trait Event extends SparkListenerEvent /** - * :: Experimental :: * Event representing the start of a query * @param id An unique query id that persists across restarts. See `StreamingQuery.id()`. * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`. * @param name User-specified name of the query, null if not specified. * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving class QueryStartedEvent private[sql]( val id: UUID, @@ -99,17 +91,14 @@ object StreamingQueryListener { val name: String) extends Event /** - * :: Experimental :: * Event representing any progress updates in a query. * @param progress The query progress updates. * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event /** - * :: Experimental :: * Event representing that termination of a query. * * @param id An unique query id that persists across restarts. See `StreamingQuery.id()`. @@ -118,7 +107,6 @@ object StreamingQueryListener { * with an exception. Otherwise, it will be `None`. * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving class QueryTerminatedEvent private[sql]( val id: UUID, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7810d9f6e964..48b0ea20e5da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker @@ -34,12 +34,10 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{Clock, SystemClock, Utils} /** - * :: Experimental :: - * A class to manage all the [[StreamingQuery]] active on a `SparkSession`. + * A class to manage all the [[StreamingQuery]] active in a `SparkSession`. 
* * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { @@ -334,5 +332,6 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } awaitTerminationLock.notifyAll() } + stateStoreCoordinator.deactivateInstances(terminatedQuery.runId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index 687b1267825f..a0c9bcc8929e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -22,10 +22,9 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * Reports information about the instantaneous status of a streaming query. * * @param message A human readable description of what the stream is currently doing. @@ -35,7 +34,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} * * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryStatus protected[sql]( val message: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 35fe6b8605fa..cedc1dce4a70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -29,17 +29,17 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger. */ -@Experimental @InterfaceStability.Evolving class StateOperatorProgress private[sql]( val numRowsTotal: Long, - val numRowsUpdated: Long) extends Serializable { + val numRowsUpdated: Long, + val memoryUsedBytes: Long + ) extends Serializable { /** The compact JSON representation of this progress. */ def json: String = compact(render(jsonValue)) @@ -47,14 +47,19 @@ class StateOperatorProgress private[sql]( /** The pretty (i.e. indented) JSON representation of this progress. */ def prettyJson: String = pretty(render(jsonValue)) + private[sql] def copy(newNumRowsUpdated: Long): StateOperatorProgress = + new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes) + private[sql] def jsonValue: JValue = { ("numRowsTotal" -> JInt(numRowsTotal)) ~ - ("numRowsUpdated" -> JInt(numRowsUpdated)) + ("numRowsUpdated" -> JInt(numRowsUpdated)) ~ + ("memoryUsedBytes" -> JInt(memoryUsedBytes)) } + + override def toString: String = prettyJson } /** - * :: Experimental :: * Information about progress made in the execution of a [[StreamingQuery]] during * a trigger. Each event relates to processing done for a single trigger of the streaming * query. Events are emitted even when no new data is available to be processed. @@ -80,7 +85,6 @@ class StateOperatorProgress private[sql]( * @param sources detailed statistics on data being read from each of the streaming sources. 
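With `memoryUsedBytes` added to `StateOperatorProgress` (and `batchId` now included in the progress JSON), state-store memory usage can be read off the normal progress API. A brief sketch, assuming `query` is an already-started `StreamingQuery` with stateful operators:

```scala
import org.apache.spark.sql.streaming.StreamingQuery

object StateMemoryReport {
  // Prints per-operator state metrics from the most recent progress update, if any.
  def reportStateMemory(query: StreamingQuery): Unit = {
    Option(query.lastProgress).foreach { progress =>
      progress.stateOperators.foreach { op =>
        println(s"rows total: ${op.numRowsTotal}, rows updated: ${op.numRowsUpdated}, " +
          s"state memory: ${op.memoryUsedBytes} bytes")
      }
      // progress.json / progress.prettyJson now also include the batchId field.
    }
  }
}
```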
* @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryProgress private[sql]( val id: UUID, @@ -127,6 +131,7 @@ class StreamingQueryProgress private[sql]( ("runId" -> JString(runId.toString)) ~ ("name" -> JString(name)) ~ ("timestamp" -> JString(timestamp)) ~ + ("batchId" -> JInt(batchId)) ~ ("numInputRows" -> JInt(numInputRows)) ~ ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~ ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) ~ @@ -139,7 +144,6 @@ class StreamingQueryProgress private[sql]( } /** - * :: Experimental :: * Information about progress made for a source in the execution of a [[StreamingQuery]] * during a trigger. See [[StreamingQueryProgress]] for more information. * @@ -152,7 +156,6 @@ class StreamingQueryProgress private[sql]( * Spark. * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class SourceProgress protected[sql]( val description: String, @@ -191,14 +194,12 @@ class SourceProgress protected[sql]( } /** - * :: Experimental :: * Information about progress made for a sink in the execution of a [[StreamingQuery]] * during a trigger. See [[StreamingQueryProgress]] for more information. * * @param description Description of the source corresponding to this status. * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class SinkProgress protected[sql]( val description: String) extends Serializable { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 3ba37addfc8b..c132cab1b38c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1283,6 +1283,80 @@ public void test() { ds.collectAsList(); } + public enum MyEnum { + A("www.elgoog.com"), + B("www.google.com"); + + private String url; + + MyEnum(String url) { + this.url = url; + } + + public String getUrl() { + return url; + } + + public void setUrl(String url) { + this.url = url; + } + } + + public static class BeanWithEnum { + MyEnum enumField; + String regularField; + + public String getRegularField() { + return regularField; + } + + public void setRegularField(String regularField) { + this.regularField = regularField; + } + + public MyEnum getEnumField() { + return enumField; + } + + public void setEnumField(MyEnum field) { + this.enumField = field; + } + + public BeanWithEnum(MyEnum enumField, String regularField) { + this.enumField = enumField; + this.regularField = regularField; + } + + public BeanWithEnum() { + } + + public String toString() { + return "BeanWithEnum(" + enumField + ", " + regularField + ")"; + } + + public int hashCode() { + return Objects.hashCode(enumField, regularField); + } + + public boolean equals(Object other) { + if (other instanceof BeanWithEnum) { + BeanWithEnum beanWithEnum = (BeanWithEnum) other; + return beanWithEnum.regularField.equals(regularField) + && beanWithEnum.enumField.equals(enumField); + } + return false; + } + } + + @Test + public void testBeanWithEnum() { + List data = Arrays.asList(new BeanWithEnum(MyEnum.A, "mira avenue"), + new BeanWithEnum(MyEnum.B, "flower boulevard")); + Encoder encoder = Encoders.bean(BeanWithEnum.class); + Dataset ds = spark.createDataset(data, encoder); + Assert.assertEquals(ds.collectAsList(), data); + } + public static class EmptyBean implements Serializable {} @Test @@ -1399,4 +1473,65 @@ public void testSerializeNull() { ds1.map((MapFunction) 
b -> b, encoder); Assert.assertEquals(beans, ds2.collectAsList()); } + + @Test + public void testSpecificLists() { + SpecificListsBean bean = new SpecificListsBean(); + ArrayList arrayList = new ArrayList<>(); + arrayList.add(1); + bean.setArrayList(arrayList); + LinkedList linkedList = new LinkedList<>(); + linkedList.add(1); + bean.setLinkedList(linkedList); + bean.setList(Collections.singletonList(1)); + List beans = Collections.singletonList(bean); + Dataset dataset = + spark.createDataset(beans, Encoders.bean(SpecificListsBean.class)); + Assert.assertEquals(beans, dataset.collectAsList()); + } + + public static class SpecificListsBean implements Serializable { + private ArrayList arrayList; + private LinkedList linkedList; + private List list; + + public ArrayList getArrayList() { + return arrayList; + } + + public void setArrayList(ArrayList arrayList) { + this.arrayList = arrayList; + } + + public LinkedList getLinkedList() { + return linkedList; + } + + public void setLinkedList(LinkedList linkedList) { + this.linkedList = linkedList; + } + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SpecificListsBean that = (SpecificListsBean) o; + return Objects.equal(arrayList, that.arrayList) && + Objects.equal(linkedList, that.linkedList) && + Objects.equal(list, that.list); + } + + @Override + public int hashCode() { + return Objects.hashCode(arrayList, linkedList, list); + } + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java new file mode 100644 index 000000000000..ddbaa45a483c --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + + +public class JavaUDAFSuite { + + private transient SparkSession spark; + + @Before + public void setUp() { + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + spark.range(1, 10).toDF("value").registerTempTable("df"); + spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName()); + Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head(); + Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6); + } + +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 250fa674d8ec..5bf188882618 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -25,6 +25,7 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.api.java.UDF2; @@ -105,4 +106,19 @@ public void udf4Test() { } Assert.assertEquals(55, sum); } + + @SuppressWarnings("unchecked") + @Test(expected = AnalysisException.class) + public void udf5Test() { + spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType); + List results = spark.sql("SELECT inc(1, 5)").collectAsList(); + } + + @SuppressWarnings("unchecked") + @Test + public void udf6Test() { + spark.udf().register("returnOne", () -> 1, DataTypes.IntegerType); + Row result = spark.sql("SELECT returnOne()").head(); + Assert.assertEquals(1, result.getInt(0)); + } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java similarity index 99% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java rename to sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java index ae0c097c362a..447a71d284fb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.aggregate; +package test.org.apache.spark.sql; import java.util.ArrayList; import java.util.List; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java similarity index 98% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java rename to sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java index d17fb3e5194f..93d20330c717 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java @@ -15,18 +15,18 @@ * limitations under the License. 
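The Java suites above exercise `registerJavaUDAF` and a zero-argument UDF; the same registration calls are available from Scala. A short sketch reusing the relocated `MyDoubleAvg` test helper by its new class name, assuming that class is on the classpath (substitute your own UDAF class in real code):

```scala
import org.apache.spark.sql.SparkSession

object UdfRegistrationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("udf-sketch").master("local[*]").getOrCreate()

    // Zero-argument UDF, mirroring udf6Test above.
    spark.udf.register("returnOne", () => 1)
    spark.sql("SELECT returnOne()").show()

    // Register a Java UserDefinedAggregateFunction by class name, mirroring JavaUDAFSuite.
    spark.udf.registerJavaUDAF("myDoubleAvg", "test.org.apache.spark.sql.MyDoubleAvg")
    spark.range(1, 10).toDF("value").createOrReplaceTempView("df")
    spark.sql("SELECT myDoubleAvg(value) AS my_avg FROM df").show()

    spark.stop()
  }
}
```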
*/ -package org.apache.spark.sql.hive.aggregate; +package test.org.apache.spark.sql; import java.util.ArrayList; import java.util.List; +import org.apache.spark.sql.Row; import org.apache.spark.sql.expressions.MutableAggregationBuffer; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; /** * An example {@link UserDefinedAggregateFunction} to calculate the sum of a diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java new file mode 100644 index 000000000000..7aacf0346d2f --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.*; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsPushDownRequiredColumns, + SupportsPushDownFilters { + + private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); + private Filter[] filters = new Filter[0]; + + @Override + public StructType readSchema() { + return requiredSchema; + } + + @Override + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; + } + + @Override + public Filter[] pushFilters(Filter[] filters) { + this.filters = filters; + return new Filter[0]; + } + + @Override + public List> createReadTasks() { + List> res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + lowerBound = (Integer) f.value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaAdvancedReadTask(0, 5, requiredSchema)); + res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + } else if (lowerBound < 4) { + res.add(new JavaAdvancedReadTask(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + } else if (lowerBound < 9) { + res.add(new JavaAdvancedReadTask(lowerBound + 1, 10, requiredSchema)); + } + + return res; + } + } + + static class JavaAdvancedReadTask implements ReadTask, DataReader { + private int start; + private int end; + private StructType requiredSchema; + + JavaAdvancedReadTask(int start, int end, StructType requiredSchema) { + this.start = start; + this.end = end; + this.requiredSchema = requiredSchema; + } + + @Override + public DataReader createReader() { + return new JavaAdvancedReadTask(start - 1, end, requiredSchema); + } + + @Override + public boolean next() { + start += 1; + return start < end; + } + + @Override + public Row get() { + Object[] values = new Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = start; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -start; + } + } + return new GenericRow(values); + } + + @Override + public void close() throws IOException { + + } + } + + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java new file mode 100644 index 000000000000..a174bd8092cb --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.types.StructType; + +public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { + + class Reader implements DataSourceV2Reader { + private final StructType schema; + + Reader(StructType schema) { + this.schema = schema; + } + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createReadTasks() { + return java.util.Collections.emptyList(); + } + } + + @Override + public DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options) { + return new Reader(schema); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java new file mode 100644 index 000000000000..08469f14c257 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.types.StructType; + +public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createReadTasks() { + return java.util.Arrays.asList( + new JavaSimpleReadTask(0, 5), + new JavaSimpleReadTask(5, 10)); + } + } + + static class JavaSimpleReadTask implements ReadTask, DataReader { + private int start; + private int end; + + JavaSimpleReadTask(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public DataReader createReader() { + return new JavaSimpleReadTask(start - 1, end); + } + + @Override + public boolean next() { + start += 1; + return start < end; + } + + @Override + public Row get() { + return new GenericRow(new Object[] {start, -start}); + } + + @Override + public void close() throws IOException { + + } + } + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java new file mode 100644 index 000000000000..9efe7c791a93 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
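The DataSourceV2 test sources above are loaded through the regular reader API by fully-qualified class name; with `JavaAdvancedDataSourceV2`, a `GreaterThan` filter and a column projection can be pushed into the reader so fewer read tasks and columns are produced. A usage sketch, assuming these test classes are on the classpath:

```scala
import org.apache.spark.sql.SparkSession

object DataSourceV2UsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("dsv2-sketch").master("local[*]").getOrCreate()

    // Simple source: two read tasks producing rows (i, -i) for i in 0 to 9.
    spark.read
      .format("test.org.apache.spark.sql.sources.v2.JavaSimpleDataSourceV2")
      .load()
      .show()

    // Advanced source: the filter and projection below are candidates for pushdown into the reader.
    spark.read
      .format("test.org.apache.spark.sql.sources.v2.JavaAdvancedDataSourceV2")
      .load()
      .filter("i > 4")
      .select("i")
      .show()

    spark.stop()
  }
}
```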
+ */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsScanUnsafeRow { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createUnsafeRowReadTasks() { + return java.util.Arrays.asList( + new JavaUnsafeRowReadTask(0, 5), + new JavaUnsafeRowReadTask(5, 10)); + } + } + + static class JavaUnsafeRowReadTask implements ReadTask, DataReader { + private int start; + private int end; + private UnsafeRow row; + + JavaUnsafeRowReadTask(int start, int end) { + this.start = start; + this.end = end; + this.row = new UnsafeRow(2); + row.pointTo(new byte[8 * 3], 8 * 3); + } + + @Override + public DataReader createReader() { + return new JavaUnsafeRowReadTask(start - 1, end); + } + + @Override + public boolean next() { + start += 1; + return start < end; + } + + @Override + public UnsafeRow get() { + row.setInt(0, start); + row.setInt(1, -start); + return row; + } + + @Override + public void close() throws IOException { + + } + } + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index cfd7889b4ac2..c6973bf41d34 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,7 @@ org.apache.spark.sql.sources.FakeSourceOne org.apache.spark.sql.sources.FakeSourceTwo org.apache.spark.sql.sources.FakeSourceThree +org.apache.spark.sql.sources.FakeSourceFour +org.apache.fakesource.FakeExternalSourceOne +org.apache.fakesource.FakeExternalSourceTwo +org.apache.fakesource.FakeExternalSourceThree diff --git a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql deleted file mode 100644 index f62b10ca0037..000000000000 --- a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql +++ /dev/null @@ -1,34 +0,0 @@ - --- unary minus and plus -select -100; -select +230; -select -5.2; -select +6.8e0; -select -key, +key from testdata where key = 2; -select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1; -select -max(key), +max(key) from testdata; -select - (-10); -select + (-key) from testdata where key = 32; -select - (+max(key)) from testdata; -select - - 3; -select - + 20; -select + + 100; -select - - max(key) from testdata; -select + - key from testdata where key = 33; - --- div -select 5 / 2; -select 5 / 0; -select 5 / null; -select null / 5; -select 5 div 2; -select 5 div 0; -select 5 div null; -select null div 5; - --- other arithmetics -select 1 + 2; -select 1 - 2; -select 2 * 5; -select 5 % 3; -select pmod(-7, 3); diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql index 5fae571945e4..629df59cff8b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql @@ -40,4 +40,6 @@ SELECT CAST('-9223372036854775809' AS long); SELECT CAST('9223372036854775807' AS long); SELECT CAST('9223372036854775808' AS long); +DESC FUNCTION boolean; +DESC FUNCTION EXTENDED boolean; -- TODO: migrate all cast tests here. diff --git a/sql/core/src/test/resources/sql-tests/inputs/comparator.sql b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql new file mode 100644 index 000000000000..3e2447723e57 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql @@ -0,0 +1,3 @@ +-- binary type +select x'00' < x'0f'; +select x'00' < x'ff'; diff --git a/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql b/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql index aa7312437487..b64197e2bc70 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql @@ -32,4 +32,5 @@ create temporary view D(d, vd) as select * from nt1; -- Allowed since cross join with C is explicit select * from ((A join B on (a = b)) cross join C) join D on (a = d); - +-- Cross joins with non-equal predicates +SELECT * FROM nt1 CROSS JOIN nt2 ON (nt1.k > nt2.k); diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 3fd1c37e7179..616b6caee3f2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -2,3 +2,9 @@ -- [SPARK-16836] current_date and current_timestamp literals select current_date = current_date(), current_timestamp = current_timestamp(); + +select to_date(null), to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd'); + +select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd'); + +select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-part-after-analyze.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-part-after-analyze.sql new file mode 100644 index 000000000000..f4239da90627 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-part-after-analyze.sql @@ -0,0 +1,34 @@ +CREATE TABLE t (key STRING, value STRING, ds STRING, hr INT) USING parquet + PARTITIONED BY (ds, hr); + +INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=10) +VALUES ('k1', 100), ('k2', 200), ('k3', 300); + +INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=11) +VALUES ('k1', 101), ('k2', 201), ('k3', 301), ('k4', 401); + +INSERT INTO TABLE t PARTITION (ds='2017-09-01', hr=5) +VALUES ('k1', 102), ('k2', 202); + +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10); + +-- Collect stats for a single partition +ANALYZE TABLE t PARTITION (ds='2017-08-01', hr=10) COMPUTE STATISTICS; + +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10); + +-- Collect stats for 2 partitions +ANALYZE TABLE t PARTITION (ds='2017-08-01') COMPUTE STATISTICS; + +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10); +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11); + +-- Collect stats for all partitions +ANALYZE TABLE t PARTITION (ds, hr) COMPUTE STATISTICS; + +DESC EXTENDED t 
PARTITION (ds='2017-08-01', hr=10); +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11); +DESC EXTENDED t PARTITION (ds='2017-09-01', hr=5); + +-- DROP TEST TABLES/VIEWS +DROP TABLE t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql new file mode 100644 index 000000000000..69bff6656c43 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql @@ -0,0 +1,29 @@ +CREATE TABLE table_with_comment (a STRING, b INT, c STRING, d STRING) USING parquet COMMENT 'added'; + +DESC FORMATTED table_with_comment; + +-- ALTER TABLE BY MODIFYING COMMENT +ALTER TABLE table_with_comment SET TBLPROPERTIES("comment"= "modified comment", "type"= "parquet"); + +DESC FORMATTED table_with_comment; + +-- DROP TEST TABLE +DROP TABLE table_with_comment; + +-- CREATE TABLE WITHOUT COMMENT +CREATE TABLE table_comment (a STRING, b INT) USING parquet; + +DESC FORMATTED table_comment; + +-- ALTER TABLE BY ADDING COMMENT +ALTER TABLE table_comment SET TBLPROPERTIES(comment = "added comment"); + +DESC formatted table_comment; + +-- ALTER UNSET PROPERTIES COMMENT +ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment'); + +DESC FORMATTED table_comment; + +-- DROP TEST TABLE +DROP TABLE table_comment; diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql new file mode 100644 index 000000000000..a6ddcd999bf9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql @@ -0,0 +1,41 @@ +-- Test temp table +CREATE TEMPORARY VIEW desc_col_temp_view (key int COMMENT 'column_comment') USING PARQUET; + +DESC desc_col_temp_view key; + +DESC EXTENDED desc_col_temp_view key; + +DESC FORMATTED desc_col_temp_view key; + +-- Describe a column with qualified name +DESC FORMATTED desc_col_temp_view desc_col_temp_view.key; + +-- Describe a non-existent column +DESC desc_col_temp_view key1; + +-- Test persistent table +CREATE TABLE desc_col_table (key int COMMENT 'column_comment') USING PARQUET; + +ANALYZE TABLE desc_col_table COMPUTE STATISTICS FOR COLUMNS key; + +DESC desc_col_table key; + +DESC EXTENDED desc_col_table key; + +DESC FORMATTED desc_col_table key; + +-- Test complex columns +CREATE TABLE desc_complex_col_table (`a.b` int, col struct) USING PARQUET; + +DESC FORMATTED desc_complex_col_table `a.b`; + +DESC FORMATTED desc_complex_col_table col; + +-- Describe a nested column +DESC FORMATTED desc_complex_col_table col.x; + +DROP VIEW desc_col_temp_view; + +DROP TABLE desc_col_table; + +DROP TABLE desc_complex_col_table; diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index 6de4cf0d5afa..f26d5efec076 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -1,6 +1,8 @@ CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + OPTIONS (a '1', b '2') PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS - COMMENT 'table_comment'; + COMMENT 'table_comment' + TBLPROPERTIES (t 'test'); CREATE TEMPORARY VIEW temp_v AS SELECT * FROM t; @@ -13,6 +15,8 @@ CREATE TEMPORARY VIEW temp_Data_Source_View CREATE VIEW v AS SELECT * FROM t; +ALTER TABLE t SET TBLPROPERTIES (e = '3'); + ALTER TABLE t ADD PARTITION (c='Us', d=1); DESCRIBE t; @@ -25,6 
+29,14 @@ DESC FORMATTED t; DESC EXTENDED t; +ALTER TABLE t UNSET TBLPROPERTIES (e); + +DESC EXTENDED t; + +ALTER TABLE t UNSET TBLPROPERTIES (comment); + +DESC EXTENDED t; + DESC t PARTITION (c='Us', d=1); DESC EXTENDED t PARTITION (c='Us', d=1); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index f8135389a9e5..8aff4cb52419 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -54,4 +54,9 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, year; SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course); SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course); -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; \ No newline at end of file +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; + +-- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS +SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2); +SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b); +SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 6566338f3d4a..928f766b4add 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -52,7 +52,7 @@ select count(a), a from (select 1 as a) tmp group by 2 having a > 0; -- mixed cases: group-by ordinals and aliases select a, a AS k, count(b) from data group by k, 1; --- turn of group by ordinal +-- turn off group by ordinal set spark.sql.groupByOrdinal=false; -- can now group by negative literal diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index a7994f3beaff..1e1384549a41 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -53,3 +53,10 @@ set spark.sql.groupByAliases=false; -- Check analysis exceptions SELECT a AS k, COUNT(b) FROM testData GROUP BY k; + +-- Aggregate with empty input and non-empty GroupBy expressions. +SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a; + +-- Aggregate with empty input and empty GroupBy expressions. 
+SELECT COUNT(1) FROM testData WHERE false; +SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql index b3ec956cd178..41d316444ed6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -49,3 +49,6 @@ select * from values ("one", count(1)), ("two", 2) as data(a, b); -- string to timestamp select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b); + +-- cross-join inline tables +EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null); diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index b3cc2cea51d4..fea069eac4d4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -4,6 +4,11 @@ describe function extended to_json; select to_json(named_struct('a', 1, 'b', 2)); select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); select to_json(array(named_struct('a', 1, 'b', 2))); +select to_json(map(named_struct('a', 1, 'b', 2), named_struct('a', 1, 'b', 2))); +select to_json(map('a', named_struct('a', 1, 'b', 2))); +select to_json(map('a', 1)); +select to_json(array(map('a',1))); +select to_json(array(map('a',1), map('b',2))); -- Check if errors handled select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')); select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)); @@ -20,3 +25,9 @@ select from_json('{"a":1}', 'a InvalidType'); select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE')); select from_json('{"a":1}', 'a INT', map('mode', 1)); select from_json(); +-- json_tuple +SELECT json_tuple('{"a" : 1, "b" : 2}', CAST(NULL AS STRING), 'b', CAST(NULL AS STRING), 'a'); +CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, "b": 2}', 'a'); +SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable; +-- Clean up +DROP VIEW IF EXISTS jsonTable; diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index 2ea35f7f3a5c..f21912a04271 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -1,23 +1,27 @@ -- limit on various data types -select * from testdata limit 2; -select * from arraydata limit 2; -select * from mapdata limit 2; +SELECT * FROM testdata LIMIT 2; +SELECT * FROM arraydata LIMIT 2; +SELECT * FROM mapdata LIMIT 2; -- foldable non-literal in limit -select * from testdata limit 2 + 1; +SELECT * FROM testdata LIMIT 2 + 1; -select * from testdata limit CAST(1 AS int); +SELECT * FROM testdata LIMIT CAST(1 AS int); -- limit must be non-negative -select * from testdata limit -1; +SELECT * FROM testdata LIMIT -1; +SELECT * FROM testData TABLESAMPLE (-1 ROWS); -- limit must be foldable -select * from testdata limit key > 3; +SELECT * FROM testdata LIMIT key > 3; -- limit must be integer -select * from testdata limit true; -select * from testdata limit 'a'; +SELECT * FROM testdata LIMIT true; +SELECT * FROM testdata LIMIT 'a'; -- limit within a 
subquery -select * from (select * from range(10) limit 5) where id > 3; +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3; + +-- limit ALL +SELECT * FROM testdata WHERE key < 3 LIMIT ALL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql new file mode 100644 index 000000000000..15d981985c55 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -0,0 +1,98 @@ + +-- unary minus and plus +select -100; +select +230; +select -5.2; +select +6.8e0; +select -key, +key from testdata where key = 2; +select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1; +select -max(key), +max(key) from testdata; +select - (-10); +select + (-key) from testdata where key = 32; +select - (+max(key)) from testdata; +select - - 3; +select - + 20; +select + + 100; +select - - max(key) from testdata; +select + - key from testdata where key = 33; + +-- div +select 5 / 2; +select 5 / 0; +select 5 / null; +select null / 5; +select 5 div 2; +select 5 div 0; +select 5 div null; +select null div 5; + +-- other arithmetics +select 1 + 2; +select 1 - 2; +select 2 * 5; +select 5 % 3; +select pmod(-7, 3); + +-- check operator precedence. +-- We follow Oracle operator precedence in the table below that lists the levels of precedence +-- among SQL operators from high to low: +------------------------------------------------------------------------------------------ +-- Operator Operation +------------------------------------------------------------------------------------------ +-- +, - identity, negation +-- *, / multiplication, division +-- +, -, || addition, subtraction, concatenation +-- =, !=, <, >, <=, >=, IS NULL, LIKE, BETWEEN, IN comparison +-- NOT exponentiation, logical negation +-- AND conjunction +-- OR disjunction +------------------------------------------------------------------------------------------ +explain select 'a' || 1 + 2; +explain select 1 - 2 || 'b'; +explain select 2 * 4 + 3 || 'b'; +explain select 3 + 1 || 'a' || 4 / 2; +explain select 1 == 1 OR 'a' || 'b' == 'ab'; +explain select 'a' || 'c' == 'ac' AND 2 == 3; + +-- math functions +select cot(1); +select cot(null); +select cot(0); +select cot(-1); + +-- ceil and ceiling +select ceiling(0); +select ceiling(1); +select ceil(1234567890123456); +select ceiling(1234567890123456); +select ceil(0.01); +select ceiling(-0.10); + +-- floor +select floor(0); +select floor(1); +select floor(1234567890123456); +select floor(0.01); +select floor(-0.10); + +-- comparison operator +select 1 > 0.00001; + +-- mod +select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null); + +-- length +select BIT_LENGTH('abc'); +select CHAR_LENGTH('abc'); +select CHARACTER_LENGTH('abc'); +select OCTET_LENGTH('abc'); + +-- abs +select abs(-3.13), abs('-2.19'); + +-- positive/negative +select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11); + +-- pmod +select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null); +select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql new file mode 100644 index 000000000000..3b3d4ad64b3e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql @@ -0,0 +1,36 @@ +-- EqualTo +select 1 = 1; +select 1 = '1'; +select 1.0 
= '1'; + +-- GreaterThan +select 1 > '1'; +select 2 > '1.0'; +select 2 > '2.0'; +select 2 > '2.2'; +select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52'); +select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52'; + +-- GreaterThanOrEqual +select 1 >= '1'; +select 2 >= '1.0'; +select 2 >= '2.0'; +select 2.0 >= '2.2'; +select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52'); +select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52'; + +-- LessThan +select 1 < '1'; +select 2 < '1.0'; +select 2 < '2.0'; +select 2.0 < '2.2'; +select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52'); +select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52'; + +-- LessThanOrEqual +select 1 <= '1'; +select 2 <= '1.0'; +select 2 <= '2.0'; +select 2.0 <= '2.2'; +select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52'); +select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'; diff --git a/sql/core/src/test/resources/sql-tests/inputs/query_regex_column.sql b/sql/core/src/test/resources/sql-tests/inputs/query_regex_column.sql new file mode 100644 index 000000000000..ad96754826a4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/query_regex_column.sql @@ -0,0 +1,52 @@ +set spark.sql.parser.quotedRegexColumnNames=false; + +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, "1", "11"), (2, "2", "22"), (3, "3", "33"), (4, "4", "44"), (5, "5", "55"), (6, "6", "66") +AS testData(key, value1, value2); + +CREATE OR REPLACE TEMPORARY VIEW testData2 AS SELECT * FROM VALUES +(1, 1, 1, 2), (1, 2, 1, 2), (2, 1, 2, 3), (2, 2, 2, 3), (3, 1, 3, 4), (3, 2, 3, 4) +AS testData2(A, B, c, d); + +-- AnalysisException +SELECT `(a)?+.+` FROM testData2 WHERE a = 1; +SELECT t.`(a)?+.+` FROM testData2 t WHERE a = 1; +SELECT `(a|b)` FROM testData2 WHERE a = 2; +SELECT `(a|b)?+.+` FROM testData2 WHERE a = 2; +SELECT SUM(`(a|b)?+.+`) FROM testData2; +SELECT SUM(`(a)`) FROM testData2; + +set spark.sql.parser.quotedRegexColumnNames=true; + +-- Regex columns +SELECT `(a)?+.+` FROM testData2 WHERE a = 1; +SELECT `(A)?+.+` FROM testData2 WHERE a = 1; +SELECT t.`(a)?+.+` FROM testData2 t WHERE a = 1; +SELECT t.`(A)?+.+` FROM testData2 t WHERE a = 1; +SELECT `(a|B)` FROM testData2 WHERE a = 2; +SELECT `(A|b)` FROM testData2 WHERE a = 2; +SELECT `(a|B)?+.+` FROM testData2 WHERE a = 2; +SELECT `(A|b)?+.+` FROM testData2 WHERE a = 2; +SELECT `(e|f)` FROM testData2; +SELECT t.`(e|f)` FROM testData2 t; +SELECT p.`(KEY)?+.+`, b, testdata2.`(b)?+.+` FROM testData p join testData2 ON p.key = testData2.a WHERE key < 3; +SELECT p.`(key)?+.+`, b, testdata2.`(b)?+.+` FROM testData p join testData2 ON p.key = testData2.a WHERE key < 3; + +set spark.sql.caseSensitive=true; + +CREATE OR REPLACE TEMPORARY VIEW testdata3 AS SELECT * FROM VALUES +(0, 1), (1, 2), (2, 3), (3, 4) +AS testdata3(a, b); + +-- Regex columns +SELECT `(A)?+.+` FROM testdata3; +SELECT `(a)?+.+` FROM testdata3; +SELECT `(A)?+.+` FROM testdata3 WHERE a > 1; +SELECT `(a)?+.+` FROM testdata3 where `a` > 1; +SELECT SUM(`a`) FROM testdata3; +SELECT SUM(`(a)`) FROM testdata3; +SELECT SUM(`(a)?+.+`) FROM testdata3; +SELECT SUM(a) FROM testdata3 GROUP BY `a`; +-- AnalysisException +SELECT SUM(a) FROM testdata3 GROUP BY `(a)`; +SELECT SUM(a) FROM testdata3 GROUP BY `(a)?+.+`; diff --git a/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql b/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql index 1e02c2f045ea..521018e94e50 100644 --- 
a/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql @@ -2,9 +2,9 @@ CREATE DATABASE showdb; USE showdb; -CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING parquet; +CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING json; CREATE TABLE showcolumn2 (price int, qty int, year int, month int) USING parquet partitioned by (year, month); -CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet; +CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING json; CREATE GLOBAL TEMP VIEW showColumn4 AS SELECT 1 as col1, 'abc' as `col 5`; diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql index 2b5b692d29ef..f1461032065a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql @@ -23,3 +23,7 @@ SELECT float(1), double(1), decimal(1); SELECT date("2014-04-04"), timestamp(date("2014-04-04")); -- error handling: only one argument SELECT string(1, 2); + +-- SPARK-21555: RuntimeReplaceable used in group by +CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st); +SELECT nvl(st.col1, "value"), count(*) FROM tempView1 GROUP BY nvl(st.col1, "value"); diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index f21981ef7b72..40d0c064f5c4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -1,3 +1,26 @@ -- Argument number exception select concat_ws(); select format_string(); + +-- A pipe operator for string concatenation +select 'a' || 'b' || 'c'; + +-- Check if catalyst combines nested `Concat`s +EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col +FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)); + +-- replace function +select replace('abc', 'b', '123'); +select replace('abc', 'b'); + +-- uuid +select length(uuid()), (uuid() <> uuid()); + +-- position +select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null); + +-- left && right +select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null); +select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a'); +select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null); +select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/struct.sql b/sql/core/src/test/resources/sql-tests/inputs/struct.sql index e56344dc4de8..93a1238ab18c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/struct.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/struct.sql @@ -18,3 +18,10 @@ SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x; -- Prepend a column to a struct SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x; + +-- Select a column from a struct +SELECT ID, STRUCT(ST.*).C NST FROM tbl_x; +SELECT ID, STRUCT(ST.C, ST.D).D NST FROM tbl_x; + +-- Select an alias from a struct +SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql
b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql new file mode 100644 index 000000000000..b15f4da81dd9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql @@ -0,0 +1,47 @@ +-- The test file contains negative test cases +-- of invalid queries where error messages are expected. + +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (1, 2, 3) +AS t1(t1a, t1b, t1c); + +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES + (1, 0, 1) +AS t2(t2a, t2b, t2c); + +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES + (3, 1, 2) +AS t3(t3a, t3b, t3c); + +-- TC 01.01 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b = t1.t1b + GROUP BY t2.t2b + ) +FROM t1; + +-- TC 01.02 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b > 0 + GROUP BY t2.t2b + ) +FROM t1; + +-- TC 01.03 +SELECT * FROM t1 +WHERE +t1a IN (SELECT t2a, t2b + FROM t2 + WHERE t1a = t2a); + +-- TC 01.04 +SELECT * FROM T1 +WHERE +(t1a, t1b) IN (SELECT t2a + FROM t2 + WHERE t1a = t2a); + diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/subquery-in-from.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/subquery-in-from.sql new file mode 100644 index 000000000000..1273b56b6344 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/subquery-in-from.sql @@ -0,0 +1,14 @@ +-- Aliased subqueries in FROM clause +SELECT * FROM (SELECT * FROM testData) AS t WHERE key = 1; + +FROM (SELECT * FROM testData WHERE key = 1) AS t SELECT *; + +-- Optional `AS` keyword +SELECT * FROM (SELECT * FROM testData) t WHERE key = 1; + +FROM (SELECT * FROM testData WHERE key = 1) t SELECT *; + +-- Disallow unaliased subqueries in FROM clause +SELECT * FROM (SELECT * FROM testData) WHERE key = 1; + +FROM (SELECT * FROM testData WHERE key = 1) SELECT *; diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql b/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql new file mode 100644 index 000000000000..4cfd5f28afda --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql @@ -0,0 +1,27 @@ +-- Test data.
+CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES (1, 1), (1, 2), (2, 1) AS testData(a, b); + +-- Table column aliases in FROM clause +SELECT * FROM testData AS t(col1, col2) WHERE col1 = 1; + +SELECT * FROM testData AS t(col1, col2) WHERE col1 = 2; + +SELECT col1 AS k, SUM(col2) FROM testData AS t(col1, col2) GROUP BY k; + +-- Aliasing the wrong number of columns in the FROM clause +SELECT * FROM testData AS t(col1, col2, col3); + +SELECT * FROM testData AS t(col1); + +-- Check alias duplication +SELECT a AS col1, b AS col2 FROM testData AS t(c, d); + +-- Subquery aliases in FROM clause +SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2); + +-- Aliases for join relations in FROM clause +CREATE OR REPLACE TEMPORARY VIEW src1 AS SELECT * FROM VALUES (1, "a"), (2, "b"), (3, "c") AS src1(id, v1); + +CREATE OR REPLACE TEMPORARY VIEW src2 AS SELECT * FROM VALUES (2, 1.0), (3, 3.2), (1, 8.5) AS src2(id, v2); + +SELECT * FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d); diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index d0d2df7b243d..72cd8ca9d872 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -24,3 +24,6 @@ select * from RaNgE(2); -- Explain EXPLAIN select * from RaNgE(2); + +-- cross-join table valued functions +EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3); diff --git a/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql b/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql new file mode 100644 index 000000000000..72508f59bee2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql @@ -0,0 +1,14 @@ +-- Negative testcases for tablesample +CREATE DATABASE mydb1; +USE mydb1; +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1; + +-- Negative tests: negative percentage +SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT); + +-- Negative tests: percentage over 100 +-- The TABLESAMPLE clause samples without replacement, so the value of PERCENT must not exceed 100 +SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT); + +-- reset +DROP DATABASE mydb1 CASCADE; diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql new file mode 100644 index 000000000000..2183ba23afc3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql @@ -0,0 +1,13 @@ +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(1), (2), (3), (4) +as t1(int_col1); + +CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg'; + +SELECT default.myDoubleAvg(int_col1) as my_avg from t1; + +SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1; + +CREATE FUNCTION udaf1 AS 'test.non.existent.udaf'; + +SELECT default.udaf1(int_col1) as udaf1 from t1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql new file mode 100644 index 000000000000..c4bea34ec4cf --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -0,0 +1,103 @@ +-- Test data. 
+CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(null, 1L, 1.0D, date("2017-08-01"), timestamp(1501545600), "a"), +(1, 1L, 1.0D, date("2017-08-01"), timestamp(1501545600), "a"), +(1, 2L, 2.5D, date("2017-08-02"), timestamp(1502000000), "a"), +(2, 2147483650L, 100.001D, date("2020-12-31"), timestamp(1609372800), "a"), +(1, null, 1.0D, date("2017-08-01"), timestamp(1501545600), "b"), +(2, 3L, 3.3D, date("2017-08-03"), timestamp(1503000000), "b"), +(3, 2147483650L, 100.001D, date("2020-12-31"), timestamp(1609372800), "b"), +(null, null, null, null, null, null), +(3, 1L, 1.0D, date("2017-08-01"), timestamp(1501545600), null) +AS testData(val, val_long, val_double, val_date, val_timestamp, cate); + +-- RowsBetween +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) FROM testData +ORDER BY cate, val; +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val +ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long; + +-- RangeBetween +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData +ORDER BY cate, val; +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +RANGE BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long; +SELECT val_double, cate, sum(val_double) OVER(PARTITION BY cate ORDER BY val_double +RANGE BETWEEN CURRENT ROW AND 2.5 FOLLOWING) FROM testData ORDER BY cate, val_double; +SELECT val_date, cate, max(val_date) OVER(PARTITION BY cate ORDER BY val_date +RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) FROM testData ORDER BY cate, val_date; +SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp +RANGE BETWEEN CURRENT ROW AND interval 23 days 4 hours FOLLOWING) FROM testData +ORDER BY cate, val_timestamp; + +-- RangeBetween with reverse OrderBy +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; + +-- Invalid window frame +SELECT val, cate, count(val) OVER(PARTITION BY cate +ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val, cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY current_timestamp +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val; + + +-- Window functions +SELECT val, cate, +max(val) OVER w AS max, +min(val) OVER w AS min, +min(val) OVER w AS min, +count(val) OVER w AS count, +sum(val) OVER w AS sum, +avg(val) OVER w AS avg, +stddev(val) OVER w AS stddev, +first_value(val) OVER w AS first_value, +first_value(val, true) 
OVER w AS first_value_ignore_null, +first_value(val, false) OVER w AS first_value_contain_null, +last_value(val) OVER w AS last_value, +last_value(val, true) OVER w AS last_value_ignore_null, +last_value(val, false) OVER w AS last_value_contain_null, +rank() OVER w AS rank, +dense_rank() OVER w AS dense_rank, +cume_dist() OVER w AS cume_dist, +percent_rank() OVER w AS percent_rank, +ntile(2) OVER w AS ntile, +row_number() OVER w AS row_number, +var_pop(val) OVER w AS var_pop, +var_samp(val) OVER w AS var_samp, +approx_count_distinct(val) OVER w AS approx_count_distinct +FROM testData +WINDOW w AS (PARTITION BY cate ORDER BY val) +ORDER BY cate, val; + +-- Null inputs +SELECT val, cate, avg(null) OVER(PARTITION BY cate ORDER BY val) FROM testData ORDER BY cate, val; + +-- OrderBy not specified +SELECT val, cate, row_number() OVER(PARTITION BY cate) FROM testData ORDER BY cate, val; + +-- Over clause is empty +SELECT val, cate, sum(val) OVER(), avg(val) OVER() FROM testData ORDER BY cate, val; + +-- first_value()/last_value() over () +SELECT val, cate, +first_value(false) OVER w AS first_value, +first_value(true, true) OVER w AS first_value_ignore_null, +first_value(false, false) OVER w AS first_value_contain_null, +last_value(false) OVER w AS last_value, +last_value(true, true) OVER w AS last_value_ignore_null, +last_value(false, false) OVER w AS last_value_contain_null +FROM testData +WINDOW w AS () +ORDER BY cate, val; diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out deleted file mode 100644 index ce42c016a710..000000000000 --- a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out +++ /dev/null @@ -1,226 +0,0 @@ --- Automatically generated by SQLQueryTestSuite --- Number of queries: 28 - - --- !query 0 -select -100 --- !query 0 schema -struct<-100:int> --- !query 0 output --100 - - --- !query 1 -select +230 --- !query 1 schema -struct<230:int> --- !query 1 output -230 - - --- !query 2 -select -5.2 --- !query 2 schema -struct<-5.2:decimal(2,1)> --- !query 2 output --5.2 - - --- !query 3 -select +6.8e0 --- !query 3 schema -struct<6.8:decimal(2,1)> --- !query 3 output -6.8 - - --- !query 4 -select -key, +key from testdata where key = 2 --- !query 4 schema -struct<(- key):int,key:int> --- !query 4 output --2 2 - - --- !query 5 -select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1 --- !query 5 schema -struct<(- (key + 1)):int,((- key) + 1):int,(key + 5):int> --- !query 5 output --2 0 6 - - --- !query 6 -select -max(key), +max(key) from testdata --- !query 6 schema -struct<(- max(key)):int,max(key):int> --- !query 6 output --100 100 - - --- !query 7 -select - (-10) --- !query 7 schema -struct<(- -10):int> --- !query 7 output -10 - - --- !query 8 -select + (-key) from testdata where key = 32 --- !query 8 schema -struct<(- key):int> --- !query 8 output --32 - - --- !query 9 -select - (+max(key)) from testdata --- !query 9 schema -struct<(- max(key)):int> --- !query 9 output --100 - - --- !query 10 -select - - 3 --- !query 10 schema -struct<(- -3):int> --- !query 10 output -3 - - --- !query 11 -select - + 20 --- !query 11 schema -struct<(- 20):int> --- !query 11 output --20 - - --- !query 12 -select + + 100 --- !query 12 schema -struct<100:int> --- !query 12 output -100 - - --- !query 13 -select - - max(key) from testdata --- !query 13 schema -struct<(- (- max(key))):int> --- !query 13 output -100 - - --- !query 14 -select + - key from testdata where key = 33 --- !query 14 
schema -struct<(- key):int> --- !query 14 output --33 - - --- !query 15 -select 5 / 2 --- !query 15 schema -struct<(CAST(5 AS DOUBLE) / CAST(2 AS DOUBLE)):double> --- !query 15 output -2.5 - - --- !query 16 -select 5 / 0 --- !query 16 schema -struct<(CAST(5 AS DOUBLE) / CAST(0 AS DOUBLE)):double> --- !query 16 output -NULL - - --- !query 17 -select 5 / null --- !query 17 schema -struct<(CAST(5 AS DOUBLE) / CAST(NULL AS DOUBLE)):double> --- !query 17 output -NULL - - --- !query 18 -select null / 5 --- !query 18 schema -struct<(CAST(NULL AS DOUBLE) / CAST(5 AS DOUBLE)):double> --- !query 18 output -NULL - - --- !query 19 -select 5 div 2 --- !query 19 schema -struct --- !query 19 output -2 - - --- !query 20 -select 5 div 0 --- !query 20 schema -struct --- !query 20 output -NULL - - --- !query 21 -select 5 div null --- !query 21 schema -struct --- !query 21 output -NULL - - --- !query 22 -select null div 5 --- !query 22 schema -struct --- !query 22 output -NULL - - --- !query 23 -select 1 + 2 --- !query 23 schema -struct<(1 + 2):int> --- !query 23 output -3 - - --- !query 24 -select 1 - 2 --- !query 24 schema -struct<(1 - 2):int> --- !query 24 output --1 - - --- !query 25 -select 2 * 5 --- !query 25 schema -struct<(2 * 5):int> --- !query 25 output -10 - - --- !query 26 -select 5 % 3 --- !query 26 schema -struct<(5 % 3):int> --- !query 26 output -2 - - --- !query 27 -select pmod(-7, 3) --- !query 27 schema -struct --- !query 27 output -2 diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out index bfa29d7d2d59..9c5f4554d9fe 100644 --- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 22 +-- Number of queries: 24 -- !query 0 @@ -176,3 +176,26 @@ SELECT CAST('9223372036854775808' AS long) struct -- !query 21 output NULL + + +-- !query 22 +DESC FUNCTION boolean +-- !query 22 schema +struct +-- !query 22 output +Class: org.apache.spark.sql.catalyst.expressions.Cast +Function: boolean +Usage: boolean(expr) - Casts the value `expr` to the target data type `boolean`. + + +-- !query 23 +DESC FUNCTION EXTENDED boolean +-- !query 23 schema +struct +-- !query 23 output +Class: org.apache.spark.sql.catalyst.expressions.Cast +Extended Usage: + No example/argument for boolean. + +Function: boolean +Usage: boolean(expr) - Casts the value `expr` to the target data type `boolean`. 
diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index 678a3f0f0a3c..ba8bc936f0c7 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -15,7 +15,6 @@ DESC test_change -- !query 1 schema struct -- !query 1 output -# col_name data_type comment a int b string c int @@ -35,7 +34,6 @@ DESC test_change -- !query 3 schema struct -- !query 3 output -# col_name data_type comment a int b string c int @@ -55,7 +53,6 @@ DESC test_change -- !query 5 schema struct -- !query 5 output -# col_name data_type comment a int b string c int @@ -94,7 +91,6 @@ DESC test_change -- !query 8 schema struct -- !query 8 output -# col_name data_type comment a int b string c int @@ -129,7 +125,6 @@ DESC test_change -- !query 12 schema struct -- !query 12 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -148,7 +143,6 @@ DESC test_change -- !query 14 schema struct -- !query 14 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -168,7 +162,6 @@ DESC test_change -- !query 16 schema struct -- !query 16 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -193,7 +186,6 @@ DESC test_change -- !query 18 schema struct -- !query 18 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -237,7 +229,6 @@ DESC test_change -- !query 23 schema struct -- !query 23 output -# col_name data_type comment a int this is column A b string #*02?` c int diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index 60bd8e9cc99d..b5a4f5c2bf65 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -72,7 +72,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 9 @@ -81,7 +81,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 10 @@ -90,7 +90,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb1.t1 struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 11 @@ -99,7 +99,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 12 @@ -108,7 +108,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 13 @@ -125,7 +125,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could 
be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 15 @@ -134,7 +134,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 16 @@ -143,7 +143,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 16 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 17 @@ -152,7 +152,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 18 @@ -161,7 +161,7 @@ SELECT db1.t1.i1 FROM t1, mydb2.t1 struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve '`db1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`db1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 19 @@ -186,7 +186,7 @@ SELECT mydb1.t1 FROM t1 struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 22 @@ -204,7 +204,7 @@ SELECT t1 FROM mydb1.t1 struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '`t1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`t1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 24 @@ -221,7 +221,7 @@ SELECT mydb1.t1.i1 FROM t1 struct<> -- !query 25 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 26 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out index 616421d6f2b2..7c451c2aa5b5 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -105,7 +105,7 @@ SELECT global_temp.view1.i1 FROM global_temp.view1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve '`global_temp.view1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`global_temp.view1.i1`' given input columns: [view1.i1]; line 1 pos 7 -- !query 13 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out index 764cad0e3943..d3ca4443cce5 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -96,7 +96,7 @@ SELECT mydb1.t1.i1 FROM t1 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 12 @@ -105,7 +105,7 @@ SELECT mydb1.t1.i1 FROM mydb1.t1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given 
input columns: [t1.i1]; line 1 pos 7 -- !query 13 @@ -154,7 +154,7 @@ SELECT mydb1.t1.i1 FROM mydb1.t1 struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 19 @@ -270,7 +270,7 @@ SELECT * FROM mydb1.t3 WHERE c1 IN struct<> -- !query 32 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t4.c3`' given input columns: [c2, c3]; line 2 pos 42 +cannot resolve '`mydb1.t4.c3`' given input columns: [t4.c2, t4.c3]; line 2 pos 42 -- !query 33 @@ -287,7 +287,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb2.t1 struct<> -- !query 34 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 35 @@ -296,7 +296,7 @@ SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1 struct<> -- !query 35 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 36 @@ -313,7 +313,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb1.t1 struct<> -- !query 37 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 38 @@ -402,7 +402,7 @@ SELECT mydb1.t5.t5.i1 FROM mydb1.t5 struct<> -- !query 48 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i1`' given input columns: [i1, t5]; line 1 pos 7 +cannot resolve '`mydb1.t5.t5.i1`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 -- !query 49 @@ -411,7 +411,7 @@ SELECT mydb1.t5.t5.i2 FROM mydb1.t5 struct<> -- !query 49 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i2`' given input columns: [i1, t5]; line 1 pos 7 +cannot resolve '`mydb1.t5.t5.i2`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 -- !query 50 diff --git a/sql/core/src/test/resources/sql-tests/results/comparator.sql.out b/sql/core/src/test/resources/sql-tests/results/comparator.sql.out new file mode 100644 index 000000000000..afc7b5448b7b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/comparator.sql.out @@ -0,0 +1,18 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 2 + + +-- !query 0 +select x'00' < x'0f' +-- !query 0 schema +struct<(X'00' < X'0F'):boolean> +-- !query 0 output +true + + +-- !query 1 +select x'00' < x'ff' +-- !query 1 schema +struct<(X'00' < X'FF'):boolean> +-- !query 1 output +true diff --git a/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out b/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out index 562e174fc0bb..3833c42bdfec 100644 --- a/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 13 -- !query 0 @@ -127,3 +127,14 @@ three 3 three 3 two 2 three 3 two 2 two 2 one 1 two 2 two 2 two 2 three 3 two 2 two 2 two 2 two 2 two 2 + + +-- !query 12 +SELECT * FROM nt1 CROSS JOIN nt2 ON (nt1.k > nt2.k) +-- !query 12 schema +struct +-- !query 12 output +three 3 one 1 +three 3 one 5 +two 2 one 1 +two 2 one 5 diff --git 
a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 032e4258500f..a28b91c77324 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 1 +-- Number of queries: 4 -- !query 0 @@ -8,3 +8,27 @@ select current_date = current_date(), current_timestamp = current_timestamp() struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean> -- !query 0 output true true + + +-- !query 1 +select to_date(null), to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd') +-- !query 1 schema +struct +-- !query 1 output +NULL 2016-12-31 2016-12-31 + + +-- !query 2 +select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd') +-- !query 2 schema +struct +-- !query 2 output +NULL 2016-12-31 00:12:00 2016-12-31 00:00:00 + + +-- !query 3 +select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15') +-- !query 3 schema +struct +-- !query 3 output +7 5 7 NULL 6 diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out new file mode 100644 index 000000000000..51dac111029e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -0,0 +1,244 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 15 + + +-- !query 0 +CREATE TABLE t (key STRING, value STRING, ds STRING, hr INT) USING parquet + PARTITIONED BY (ds, hr) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=10) +VALUES ('k1', 100), ('k2', 200), ('k3', 300) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=11) +VALUES ('k1', 101), ('k2', 201), ('k3', 301), ('k4', 401) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +INSERT INTO TABLE t PARTITION (ds='2017-09-01', hr=5) +VALUES ('k1', 102), ('k2', 202) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10) +-- !query 4 schema +struct +-- !query 4 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=10] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 5 +ANALYZE TABLE t PARTITION (ds='2017-08-01', hr=10) COMPUTE STATISTICS +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10) +-- !query 6 schema +struct +-- !query 6 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=10] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Partition Statistics 1067 bytes, 3 rows + +# Storage Information +Location [not included in 
comparison]sql/core/spark-warehouse/t + + +-- !query 7 +ANALYZE TABLE t PARTITION (ds='2017-08-01') COMPUTE STATISTICS +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10) +-- !query 8 schema +struct +-- !query 8 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=10] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Partition Statistics 1067 bytes, 3 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 9 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11) +-- !query 9 schema +struct +-- !query 9 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=11] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 +Partition Statistics 1080 bytes, 4 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 10 +ANALYZE TABLE t PARTITION (ds, hr) COMPUTE STATISTICS +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10) +-- !query 11 schema +struct +-- !query 11 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=10] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Partition Statistics 1067 bytes, 3 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 12 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11) +-- !query 12 schema +struct +-- !query 12 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=11] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 +Partition Statistics 1080 bytes, 4 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 13 +DESC EXTENDED t PARTITION (ds='2017-09-01', hr=5) +-- !query 13 schema +struct +-- !query 13 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-09-01, hr=5] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 +Partition Statistics 1054 bytes, 2 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 14 +DROP TABLE t +-- !query 14 schema +struct<> +-- !query 14 output + diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out new file mode 100644 index 000000000000..7873085da506 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out @@ -0,0 +1,161 @@ +-- 
Automatically generated by SQLQueryTestSuite +-- Number of queries: 12 + + +-- !query 0 +CREATE TABLE table_with_comment (a STRING, b INT, c STRING, d STRING) USING parquet COMMENT 'added' +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +DESC FORMATTED table_with_comment +-- !query 1 schema +struct +-- !query 1 output +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table table_with_comment +Created Time [not included in comparison] +Last Access [not included in comparison] +Created By [not included in comparison] +Type MANAGED +Provider parquet +Comment added +Location [not included in comparison]sql/core/spark-warehouse/table_with_comment + + +-- !query 2 +ALTER TABLE table_with_comment SET TBLPROPERTIES("comment"= "modified comment", "type"= "parquet") +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +DESC FORMATTED table_with_comment +-- !query 3 schema +struct +-- !query 3 output +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table table_with_comment +Created Time [not included in comparison] +Last Access [not included in comparison] +Created By [not included in comparison] +Type MANAGED +Provider parquet +Comment modified comment +Table Properties [type=parquet] +Location [not included in comparison]sql/core/spark-warehouse/table_with_comment + + +-- !query 4 +DROP TABLE table_with_comment +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TABLE table_comment (a STRING, b INT) USING parquet +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +DESC FORMATTED table_comment +-- !query 6 schema +struct +-- !query 6 output +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created Time [not included in comparison] +Last Access [not included in comparison] +Created By [not included in comparison] +Type MANAGED +Provider parquet +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 7 +ALTER TABLE table_comment SET TBLPROPERTIES(comment = "added comment") +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +DESC formatted table_comment +-- !query 8 schema +struct +-- !query 8 output +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created Time [not included in comparison] +Last Access [not included in comparison] +Created By [not included in comparison] +Type MANAGED +Provider parquet +Comment added comment +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 9 +ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment') +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +DESC FORMATTED table_comment +-- !query 10 schema +struct +-- !query 10 output +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created Time [not included in comparison] +Last Access [not included in comparison] +Created By [not included in comparison] +Type MANAGED +Provider parquet +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 11 +DROP TABLE table_comment +-- !query 11 schema +struct<> +-- !query 11 output + diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out new file mode 100644 index 000000000000..30d0a2dc5a3f --- /dev/null +++ 
b/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out @@ -0,0 +1,208 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 18 + + +-- !query 0 +CREATE TEMPORARY VIEW desc_col_temp_view (key int COMMENT 'column_comment') USING PARQUET +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +DESC desc_col_temp_view key +-- !query 1 schema +struct +-- !query 1 output +col_name key +data_type int +comment column_comment + + +-- !query 2 +DESC EXTENDED desc_col_temp_view key +-- !query 2 schema +struct +-- !query 2 output +col_name key +data_type int +comment column_comment +min NULL +max NULL +num_nulls NULL +distinct_count NULL +avg_col_len NULL +max_col_len NULL + + +-- !query 3 +DESC FORMATTED desc_col_temp_view key +-- !query 3 schema +struct +-- !query 3 output +col_name key +data_type int +comment column_comment +min NULL +max NULL +num_nulls NULL +distinct_count NULL +avg_col_len NULL +max_col_len NULL + + +-- !query 4 +DESC FORMATTED desc_col_temp_view desc_col_temp_view.key +-- !query 4 schema +struct +-- !query 4 output +col_name key +data_type int +comment column_comment +min NULL +max NULL +num_nulls NULL +distinct_count NULL +avg_col_len NULL +max_col_len NULL + + +-- !query 5 +DESC desc_col_temp_view key1 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Column key1 does not exist; + + +-- !query 6 +CREATE TABLE desc_col_table (key int COMMENT 'column_comment') USING PARQUET +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +ANALYZE TABLE desc_col_table COMPUTE STATISTICS FOR COLUMNS key +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +DESC desc_col_table key +-- !query 8 schema +struct +-- !query 8 output +col_name key +data_type int +comment column_comment + + +-- !query 9 +DESC EXTENDED desc_col_table key +-- !query 9 schema +struct +-- !query 9 output +col_name key +data_type int +comment column_comment +min NULL +max NULL +num_nulls 0 +distinct_count 0 +avg_col_len 4 +max_col_len 4 + + +-- !query 10 +DESC FORMATTED desc_col_table key +-- !query 10 schema +struct +-- !query 10 output +col_name key +data_type int +comment column_comment +min NULL +max NULL +num_nulls 0 +distinct_count 0 +avg_col_len 4 +max_col_len 4 + + +-- !query 11 +CREATE TABLE desc_complex_col_table (`a.b` int, col struct) USING PARQUET +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +DESC FORMATTED desc_complex_col_table `a.b` +-- !query 12 schema +struct +-- !query 12 output +col_name a.b +data_type int +comment NULL +min NULL +max NULL +num_nulls NULL +distinct_count NULL +avg_col_len NULL +max_col_len NULL + + +-- !query 13 +DESC FORMATTED desc_complex_col_table col +-- !query 13 schema +struct +-- !query 13 output +col_name col +data_type struct +comment NULL +min NULL +max NULL +num_nulls NULL +distinct_count NULL +avg_col_len NULL +max_col_len NULL + + +-- !query 14 +DESC FORMATTED desc_complex_col_table col.x +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +DESC TABLE COLUMN command does not support nested data types: col.x; + + +-- !query 15 +DROP VIEW desc_col_temp_view +-- !query 15 schema +struct<> +-- !query 15 output + + + +-- !query 16 +DROP TABLE desc_col_table +-- !query 16 schema +struct<> +-- !query 16 output + + + +-- !query 17 +DROP TABLE desc_complex_col_table +-- !query 17 schema +struct<> +-- !query 17 output + diff --git 
a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index de10b29f3c65..8c908b762505 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,11 +1,13 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 31 +-- Number of queries: 36 -- !query 0 CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + OPTIONS (a '1', b '2') PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS COMMENT 'table_comment' + TBLPROPERTIES (t 'test') -- !query 0 schema struct<> -- !query 0 output @@ -42,7 +44,7 @@ struct<> -- !query 4 -ALTER TABLE t ADD PARTITION (c='Us', d=1) +ALTER TABLE t SET TBLPROPERTIES (e = '3') -- !query 4 schema struct<> -- !query 4 output @@ -50,11 +52,18 @@ struct<> -- !query 5 -DESCRIBE t +ALTER TABLE t ADD PARTITION (c='Us', d=1) -- !query 5 schema -struct +struct<> -- !query 5 output -# col_name data_type comment + + + +-- !query 6 +DESCRIBE t +-- !query 6 schema +struct +-- !query 6 output a string b int c string @@ -65,12 +74,11 @@ c string d string --- !query 6 +-- !query 7 DESC default.t --- !query 6 schema +-- !query 7 schema struct --- !query 6 output -# col_name data_type comment +-- !query 7 output a string b int c string @@ -81,12 +89,11 @@ c string d string --- !query 7 +-- !query 8 DESC TABLE t --- !query 7 schema +-- !query 8 schema struct --- !query 7 output -# col_name data_type comment +-- !query 8 output a string b int c string @@ -97,12 +104,11 @@ c string d string --- !query 8 +-- !query 9 DESC FORMATTED t --- !query 8 schema +-- !query 9 schema struct --- !query 8 output -# col_name data_type comment +-- !query 9 output a string b int c string @@ -115,24 +121,66 @@ d string # Detailed Table Information Database default Table t -Created [not included in comparison] +Created Time [not included in comparison] Last Access [not included in comparison] +Created By [not included in comparison] Type MANAGED Provider parquet Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] Comment table_comment +Table Properties [t=test, e=3] Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] Partition Provider Catalog --- !query 9 +-- !query 10 DESC EXTENDED t --- !query 9 schema +-- !query 10 schema struct --- !query 9 output +-- !query 10 output +a string +b int +c string +d string +# Partition Information # col_name data_type comment +c string +d string + +# Detailed Table Information +Database default +Table t +Created Time [not included in comparison] +Last Access [not included in comparison] +Created By [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Comment table_comment +Table Properties [t=test, e=3] +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] +Partition Provider Catalog + + +-- !query 11 +ALTER TABLE t UNSET TBLPROPERTIES (e) +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +DESC EXTENDED t +-- !query 12 schema +struct +-- !query 12 output a string b int c string @@ -145,24 +193,65 @@ d string # Detailed Table Information Database default Table t -Created [not included in comparison] +Created Time [not included in comparison] Last Access [not included in comparison] +Created By [not included in comparison] Type MANAGED Provider parquet Num Buckets 2 Bucket Columns [`a`] 
Sort Columns [`b`] Comment table_comment +Table Properties [t=test] Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] Partition Provider Catalog --- !query 10 -DESC t PARTITION (c='Us', d=1) --- !query 10 schema +-- !query 13 +ALTER TABLE t UNSET TBLPROPERTIES (comment) +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +DESC EXTENDED t +-- !query 14 schema struct --- !query 10 output +-- !query 14 output +a string +b int +c string +d string +# Partition Information # col_name data_type comment +c string +d string + +# Detailed Table Information +Database default +Table t +Created Time [not included in comparison] +Last Access [not included in comparison] +Created By [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Table Properties [t=test] +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] +Partition Provider Catalog + + +-- !query 15 +DESC t PARTITION (c='Us', d=1) +-- !query 15 schema +struct +-- !query 15 output a string b int c string @@ -173,12 +262,11 @@ c string d string --- !query 11 +-- !query 16 DESC EXTENDED t PARTITION (c='Us', d=1) --- !query 11 schema +-- !query 16 schema struct --- !query 11 output -# col_name data_type comment +-- !query 16 output a string b int c string @@ -193,20 +281,21 @@ Database default Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 +Storage Properties [a=1, b=2] # Storage Information Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] -Location [not included in comparison]sql/core/spark-warehouse/t +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] --- !query 12 +-- !query 17 DESC FORMATTED t PARTITION (c='Us', d=1) --- !query 12 schema +-- !query 17 schema struct --- !query 12 output -# col_name data_type comment +-- !query 17 output a string b int c string @@ -221,39 +310,41 @@ Database default Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 +Storage Properties [a=1, b=2] # Storage Information Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] -Location [not included in comparison]sql/core/spark-warehouse/t +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] --- !query 13 +-- !query 18 DESC t PARTITION (c='Us', d=2) --- !query 13 schema +-- !query 18 schema struct<> --- !query 13 output +-- !query 18 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 14 +-- !query 19 DESC t PARTITION (c='Us') --- !query 14 schema +-- !query 19 schema struct<> --- !query 14 output +-- !query 19 output org.apache.spark.sql.AnalysisException Partition spec is invalid. 
The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 15 +-- !query 20 DESC t PARTITION (c='Us', d) --- !query 15 schema +-- !query 20 schema struct<> --- !query 15 output +-- !query 20 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -263,60 +354,55 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 16 +-- !query 21 DESC temp_v --- !query 16 schema +-- !query 21 schema struct --- !query 16 output -# col_name data_type comment +-- !query 21 output a string b int c string d string --- !query 17 +-- !query 22 DESC TABLE temp_v --- !query 17 schema +-- !query 22 schema struct --- !query 17 output -# col_name data_type comment +-- !query 22 output a string b int c string d string --- !query 18 +-- !query 23 DESC FORMATTED temp_v --- !query 18 schema +-- !query 23 schema struct --- !query 18 output -# col_name data_type comment +-- !query 23 output a string b int c string d string --- !query 19 +-- !query 24 DESC EXTENDED temp_v --- !query 19 schema +-- !query 24 schema struct --- !query 19 output -# col_name data_type comment +-- !query 24 output a string b int c string d string --- !query 20 +-- !query 25 DESC temp_Data_Source_View --- !query 20 schema +-- !query 25 schema struct --- !query 20 output -# col_name data_type comment +-- !query 25 output intType int test comment test1 stringType string dateType date @@ -335,45 +421,42 @@ arrayType array structType struct --- !query 21 +-- !query 26 DESC temp_v PARTITION (c='Us', d=1) --- !query 21 schema +-- !query 26 schema struct<> --- !query 21 output +-- !query 26 output org.apache.spark.sql.AnalysisException DESC PARTITION is not allowed on a temporary view: temp_v; --- !query 22 +-- !query 27 DESC v --- !query 22 schema +-- !query 27 schema struct --- !query 22 output -# col_name data_type comment +-- !query 27 output a string b int c string d string --- !query 23 +-- !query 28 DESC TABLE v --- !query 23 schema +-- !query 28 schema struct --- !query 23 output -# col_name data_type comment +-- !query 28 output a string b int c string d string --- !query 24 +-- !query 29 DESC FORMATTED v --- !query 24 schema +-- !query 29 schema struct --- !query 24 output -# col_name data_type comment +-- !query 29 output a string b int c string @@ -382,21 +465,21 @@ d string # Detailed Table Information Database default Table v -Created [not included in comparison] +Created Time [not included in comparison] Last Access [not included in comparison] +Created By [not included in comparison] Type VIEW View Text SELECT * FROM t View Default Database default View Query Output Columns [a, b, c, d] -Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] +Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] --- !query 25 +-- !query 30 DESC EXTENDED v --- !query 25 schema +-- !query 30 schema struct --- !query 25 output -# col_name data_type comment +-- !query 30 output a string b int c string @@ -405,51 +488,52 @@ d string # Detailed Table Information Database default Table v -Created [not included in comparison] +Created Time [not included in comparison] Last Access [not included in comparison] +Created By [not included in comparison] Type VIEW View Text SELECT * FROM t View Default Database default View Query Output Columns [a, b, c, 
d] -Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] +Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] --- !query 26 +-- !query 31 DESC v PARTITION (c='Us', d=1) --- !query 26 schema +-- !query 31 schema struct<> --- !query 26 output +-- !query 31 output org.apache.spark.sql.AnalysisException DESC PARTITION is not allowed on a view: v; --- !query 27 +-- !query 32 DROP TABLE t --- !query 27 schema +-- !query 32 schema struct<> --- !query 27 output +-- !query 32 output --- !query 28 +-- !query 33 DROP VIEW temp_v --- !query 28 schema +-- !query 33 schema struct<> --- !query 28 output +-- !query 33 output --- !query 29 +-- !query 34 DROP VIEW temp_Data_Source_View --- !query 29 schema +-- !query 34 schema struct<> --- !query 29 output +-- !query 34 output --- !query 30 +-- !query 35 DROP VIEW v --- !query 30 schema +-- !query 35 schema struct<> --- !query 30 output +-- !query 35 output diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out index 825e8f5488c8..ce7a16a4d0c8 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 29 -- !query 0 @@ -328,3 +328,50 @@ struct<> -- !query 25 output org.apache.spark.sql.AnalysisException grouping__id is deprecated; use grouping_id() instead; + + +-- !query 26 +SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2) +-- !query 26 schema +struct +-- !query 26 output +2 1 0 +2 NULL 0 +3 1 1 +3 2 -1 +3 NULL 0 +4 1 2 +4 2 0 +4 NULL 2 +5 2 1 +5 NULL 1 +NULL 1 3 +NULL 2 0 +NULL NULL 3 + + +-- !query 27 +SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b) +-- !query 27 schema +struct +-- !query 27 output +2 1 0 +2 NULL 0 +3 1 1 +3 2 -1 +3 NULL 0 +4 1 2 +4 2 0 +4 NULL 2 +5 2 1 +5 NULL 1 +NULL NULL 3 + + +-- !query 28 +SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k) +-- !query 28 schema +struct<(a + b):int,k:int,sum((a - b)):bigint> +-- !query 28 output +NULL 1 3 +NULL 2 0 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 6bf9dff883c1..986bb01c13fe 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 22 +-- Number of queries: 25 -- !query 0 @@ -202,4 +202,28 @@ SELECT a AS k, COUNT(b) FROM testData GROUP BY k struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 +cannot resolve '`k`' given input columns: [testdata.a, testdata.b]; line 1 pos 47 + + +-- !query 22 +SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a +-- !query 22 schema +struct +-- !query 22 output + + + +-- !query 23 +SELECT COUNT(1) FROM testData WHERE false +-- !query 23 schema +struct +-- !query 23 output +0 + + +-- !query 24 +SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t +-- !query 24 schema +struct<1:int> +-- 
!query 24 output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index 4e80f0bda551..c065ce501292 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 18 -- !query 0 @@ -151,3 +151,33 @@ select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991- struct> -- !query 16 output 1991-12-06 00:00:00 [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0] + + +-- !query 17 +EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null) +-- !query 17 schema +struct +-- !query 17 output +== Parsed Logical Plan == +'Project [*] ++- 'Join Cross + :- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)] + +- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)] + +== Analyzed Logical Plan == +col1: string, col2: int, col1: string, col2: int +Project [col1#x, col2#x, col1#x, col2#x] ++- Join Cross + :- LocalRelation [col1#x, col2#x] + +- LocalRelation [col1#x, col2#x] + +== Optimized Logical Plan == +Join Cross +:- LocalRelation [col1#x, col2#x] ++- LocalRelation [col1#x, col2#x] + +== Physical Plan == +BroadcastNestedLoopJoin BuildRight, Cross +:- LocalTableScan [col1#x, col2#x] ++- BroadcastExchange IdentityBroadcastMode + +- LocalTableScan [col1#x, col2#x] diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index fedabaee2237..d9dc728a18e8 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 26 -- !query 0 @@ -26,7 +26,17 @@ Extended Usage: {"time":"26/08/2015"} > SELECT to_json(array(named_struct('a', 1, 'b', 2)); [{"a":1,"b":2}] + > SELECT to_json(map('a', named_struct('b', 1))); + {"a":{"b":1}} + > SELECT to_json(map(named_struct('a', 1),named_struct('b', 2))); + {"[1]":{"b":2}} + > SELECT to_json(map('a', 1)); + {"a":1} + > SELECT to_json(array((map('a', 1)))); + [{"a":1}] + Since: 2.2.0 + Function: to_json Usage: to_json(expr[, options]) - Returns a json string with a given struct value @@ -56,47 +66,87 @@ struct -- !query 5 -select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')) +select to_json(map(named_struct('a', 1, 'b', 2), named_struct('a', 1, 'b', 2))) -- !query 5 schema -struct<> +struct -- !query 5 output +{"[1,2]":{"a":1,"b":2}} + + +-- !query 6 +select to_json(map('a', named_struct('a', 1, 'b', 2))) +-- !query 6 schema +struct +-- !query 6 output +{"a":{"a":1,"b":2}} + + +-- !query 7 +select to_json(map('a', 1)) +-- !query 7 schema +struct +-- !query 7 output +{"a":1} + + +-- !query 8 +select to_json(array(map('a',1))) +-- !query 8 schema +struct +-- !query 8 output +[{"a":1}] + + +-- !query 9 +select to_json(array(map('a',1), map('b',2))) +-- !query 9 schema +struct +-- !query 9 output +[{"a":1},{"b":2}] + + +-- !query 10 +select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')) +-- !query 10 schema +struct<> +-- !query 10 output org.apache.spark.sql.AnalysisException Must use a map() function for options;; line 1 pos 
7 --- !query 6 +-- !query 11 select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) --- !query 6 schema +-- !query 11 schema struct<> --- !query 6 output +-- !query 11 output org.apache.spark.sql.AnalysisException A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 --- !query 7 +-- !query 12 select to_json() --- !query 7 schema +-- !query 12 schema struct<> --- !query 7 output +-- !query 12 output org.apache.spark.sql.AnalysisException Invalid number of arguments for function to_json; line 1 pos 7 --- !query 8 +-- !query 13 describe function from_json --- !query 8 schema +-- !query 13 schema struct --- !query 8 output +-- !query 13 output Class: org.apache.spark.sql.catalyst.expressions.JsonToStructs Function: from_json Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. --- !query 9 +-- !query 14 describe function extended from_json --- !query 9 schema +-- !query 14 schema struct --- !query 9 output +-- !query 14 output Class: org.apache.spark.sql.catalyst.expressions.JsonToStructs Extended Usage: Examples: @@ -105,40 +155,42 @@ Extended Usage: > SELECT from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); {"time":"2015-08-26 00:00:00.0"} + Since: 2.2.0 + Function: from_json Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. --- !query 10 +-- !query 15 select from_json('{"a":1}', 'a INT') --- !query 10 schema +-- !query 15 schema struct> --- !query 10 output +-- !query 15 output {"a":1} --- !query 11 +-- !query 16 select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) --- !query 11 schema +-- !query 16 schema struct> --- !query 11 output +-- !query 16 output {"time":2015-08-26 00:00:00.0} --- !query 12 +-- !query 17 select from_json('{"a":1}', 1) --- !query 12 schema +-- !query 17 schema struct<> --- !query 12 output +-- !query 17 output org.apache.spark.sql.AnalysisException Expected a string literal instead of 1;; line 1 pos 7 --- !query 13 +-- !query 18 select from_json('{"a":1}', 'a InvalidType') --- !query 13 schema +-- !query 18 schema struct<> --- !query 13 output +-- !query 18 output org.apache.spark.sql.AnalysisException DataType invalidtype is not supported.(line 1, pos 2) @@ -149,28 +201,60 @@ a InvalidType ; line 1 pos 7 --- !query 14 +-- !query 19 select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE')) --- !query 14 schema +-- !query 19 schema struct<> --- !query 14 output +-- !query 19 output org.apache.spark.sql.AnalysisException Must use a map() function for options;; line 1 pos 7 --- !query 15 +-- !query 20 select from_json('{"a":1}', 'a INT', map('mode', 1)) --- !query 15 schema +-- !query 20 schema struct<> --- !query 15 output +-- !query 20 output org.apache.spark.sql.AnalysisException A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 --- !query 16 +-- !query 21 select from_json() --- !query 16 schema +-- !query 21 schema struct<> --- !query 16 output +-- !query 21 output org.apache.spark.sql.AnalysisException Invalid number of arguments for function from_json; line 1 pos 7 + + +-- !query 22 +SELECT json_tuple('{"a" : 1, "b" : 2}', CAST(NULL AS STRING), 'b', CAST(NULL AS STRING), 'a') +-- !query 22 schema +struct +-- !query 22 output +NULL 2 NULL 1 + + +-- !query 23 +CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM 
VALUES ('{"a": 1, "b": 2}', 'a') +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable +-- !query 24 schema +struct +-- !query 24 output +2 NULL 1 + + +-- !query 25 +DROP VIEW IF EXISTS jsonTable +-- !query 25 schema +struct<> +-- !query 25 output + diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index cb4e4d04810d..146abe6cbd05 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,9 +1,9 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 12 -- !query 0 -select * from testdata limit 2 +SELECT * FROM testdata LIMIT 2 -- !query 0 schema struct -- !query 0 output @@ -12,7 +12,7 @@ struct -- !query 1 -select * from arraydata limit 2 +SELECT * FROM arraydata LIMIT 2 -- !query 1 schema struct,nestedarraycol:array>> -- !query 1 output @@ -21,7 +21,7 @@ struct,nestedarraycol:array>> -- !query 2 -select * from mapdata limit 2 +SELECT * FROM mapdata LIMIT 2 -- !query 2 schema struct> -- !query 2 output @@ -30,7 +30,7 @@ struct> -- !query 3 -select * from testdata limit 2 + 1 +SELECT * FROM testdata LIMIT 2 + 1 -- !query 3 schema struct -- !query 3 output @@ -40,7 +40,7 @@ struct -- !query 4 -select * from testdata limit CAST(1 AS int) +SELECT * FROM testdata LIMIT CAST(1 AS int) -- !query 4 schema struct -- !query 4 output @@ -48,7 +48,7 @@ struct -- !query 5 -select * from testdata limit -1 +SELECT * FROM testdata LIMIT -1 -- !query 5 schema struct<> -- !query 5 output @@ -57,35 +57,53 @@ The limit expression must be equal to or greater than 0, but got -1; -- !query 6 -select * from testdata limit key > 3 +SELECT * FROM testData TABLESAMPLE (-1 ROWS) -- !query 6 schema struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); +The limit expression must be equal to or greater than 0, but got -1; -- !query 7 -select * from testdata limit true +SELECT * FROM testdata LIMIT key > 3 -- !query 7 schema struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got boolean; +The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); -- !query 8 -select * from testdata limit 'a' +SELECT * FROM testdata LIMIT true -- !query 8 schema struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got string; +The limit expression must be integer type, but got boolean; -- !query 9 -select * from (select * from range(10) limit 5) where id > 3 +SELECT * FROM testdata LIMIT 'a' -- !query 9 schema -struct +struct<> -- !query 9 output +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got string; + + +-- !query 10 +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 +-- !query 10 schema +struct +-- !query 10 output 4 + + +-- !query 11 +SELECT * FROM testdata WHERE key < 3 LIMIT ALL +-- !query 11 schema +struct +-- !query 11 output +1 1 +2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out new file mode 100644 index 000000000000..237b618a8b90 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -0,0 
+1,486 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 59 + + +-- !query 0 +select -100 +-- !query 0 schema +struct<-100:int> +-- !query 0 output +-100 + + +-- !query 1 +select +230 +-- !query 1 schema +struct<230:int> +-- !query 1 output +230 + + +-- !query 2 +select -5.2 +-- !query 2 schema +struct<-5.2:decimal(2,1)> +-- !query 2 output +-5.2 + + +-- !query 3 +select +6.8e0 +-- !query 3 schema +struct<6.8:decimal(2,1)> +-- !query 3 output +6.8 + + +-- !query 4 +select -key, +key from testdata where key = 2 +-- !query 4 schema +struct<(- key):int,key:int> +-- !query 4 output +-2 2 + + +-- !query 5 +select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1 +-- !query 5 schema +struct<(- (key + 1)):int,((- key) + 1):int,(key + 5):int> +-- !query 5 output +-2 0 6 + + +-- !query 6 +select -max(key), +max(key) from testdata +-- !query 6 schema +struct<(- max(key)):int,max(key):int> +-- !query 6 output +-100 100 + + +-- !query 7 +select - (-10) +-- !query 7 schema +struct<(- -10):int> +-- !query 7 output +10 + + +-- !query 8 +select + (-key) from testdata where key = 32 +-- !query 8 schema +struct<(- key):int> +-- !query 8 output +-32 + + +-- !query 9 +select - (+max(key)) from testdata +-- !query 9 schema +struct<(- max(key)):int> +-- !query 9 output +-100 + + +-- !query 10 +select - - 3 +-- !query 10 schema +struct<(- -3):int> +-- !query 10 output +3 + + +-- !query 11 +select - + 20 +-- !query 11 schema +struct<(- 20):int> +-- !query 11 output +-20 + + +-- !query 12 +select + + 100 +-- !query 12 schema +struct<100:int> +-- !query 12 output +100 + + +-- !query 13 +select - - max(key) from testdata +-- !query 13 schema +struct<(- (- max(key))):int> +-- !query 13 output +100 + + +-- !query 14 +select + - key from testdata where key = 33 +-- !query 14 schema +struct<(- key):int> +-- !query 14 output +-33 + + +-- !query 15 +select 5 / 2 +-- !query 15 schema +struct<(CAST(5 AS DOUBLE) / CAST(2 AS DOUBLE)):double> +-- !query 15 output +2.5 + + +-- !query 16 +select 5 / 0 +-- !query 16 schema +struct<(CAST(5 AS DOUBLE) / CAST(0 AS DOUBLE)):double> +-- !query 16 output +NULL + + +-- !query 17 +select 5 / null +-- !query 17 schema +struct<(CAST(5 AS DOUBLE) / CAST(NULL AS DOUBLE)):double> +-- !query 17 output +NULL + + +-- !query 18 +select null / 5 +-- !query 18 schema +struct<(CAST(NULL AS DOUBLE) / CAST(5 AS DOUBLE)):double> +-- !query 18 output +NULL + + +-- !query 19 +select 5 div 2 +-- !query 19 schema +struct +-- !query 19 output +2 + + +-- !query 20 +select 5 div 0 +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select 5 div null +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select null div 5 +-- !query 22 schema +struct +-- !query 22 output +NULL + + +-- !query 23 +select 1 + 2 +-- !query 23 schema +struct<(1 + 2):int> +-- !query 23 output +3 + + +-- !query 24 +select 1 - 2 +-- !query 24 schema +struct<(1 - 2):int> +-- !query 24 output +-1 + + +-- !query 25 +select 2 * 5 +-- !query 25 schema +struct<(2 * 5):int> +-- !query 25 output +10 + + +-- !query 26 +select 5 % 3 +-- !query 26 schema +struct<(5 % 3):int> +-- !query 26 output +2 + + +-- !query 27 +select pmod(-7, 3) +-- !query 27 schema +struct +-- !query 27 output +2 + + +-- !query 28 +explain select 'a' || 1 + 2 +-- !query 28 schema +struct +-- !query 28 output +== Physical Plan == +*Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x] ++- Scan OneRowRelation[] + + +-- !query 29 +explain select 1 - 2 || 'b' +-- 
!query 29 schema +struct +-- !query 29 output +== Physical Plan == +*Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x] ++- Scan OneRowRelation[] + + +-- !query 30 +explain select 2 * 4 + 3 || 'b' +-- !query 30 schema +struct +-- !query 30 output +== Physical Plan == +*Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x] ++- Scan OneRowRelation[] + + +-- !query 31 +explain select 3 + 1 || 'a' || 4 / 2 +-- !query 31 schema +struct +-- !query 31 output +== Physical Plan == +*Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x] ++- Scan OneRowRelation[] + + +-- !query 32 +explain select 1 == 1 OR 'a' || 'b' == 'ab' +-- !query 32 schema +struct +-- !query 32 output +== Physical Plan == +*Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x] ++- Scan OneRowRelation[] + + +-- !query 33 +explain select 'a' || 'c' == 'ac' AND 2 == 3 +-- !query 33 schema +struct +-- !query 33 output +== Physical Plan == +*Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x] ++- Scan OneRowRelation[] + + +-- !query 34 +select cot(1) +-- !query 34 schema +struct +-- !query 34 output +0.6420926159343306 + + +-- !query 35 +select cot(null) +-- !query 35 schema +struct +-- !query 35 output +NULL + + +-- !query 36 +select cot(0) +-- !query 36 schema +struct +-- !query 36 output +Infinity + + +-- !query 37 +select cot(-1) +-- !query 37 schema +struct +-- !query 37 output +-0.6420926159343306 + + +-- !query 38 +select ceiling(0) +-- !query 38 schema +struct +-- !query 38 output +0 + + +-- !query 39 +select ceiling(1) +-- !query 39 schema +struct +-- !query 39 output +1 + + +-- !query 40 +select ceil(1234567890123456) +-- !query 40 schema +struct +-- !query 40 output +1234567890123456 + + +-- !query 41 +select ceiling(1234567890123456) +-- !query 41 schema +struct +-- !query 41 output +1234567890123456 + + +-- !query 42 +select ceil(0.01) +-- !query 42 schema +struct +-- !query 42 output +1 + + +-- !query 43 +select ceiling(-0.10) +-- !query 43 schema +struct +-- !query 43 output +0 + + +-- !query 44 +select floor(0) +-- !query 44 schema +struct +-- !query 44 output +0 + + +-- !query 45 +select floor(1) +-- !query 45 schema +struct +-- !query 45 output +1 + + +-- !query 46 +select floor(1234567890123456) +-- !query 46 schema +struct +-- !query 46 output +1234567890123456 + + +-- !query 47 +select floor(0.01) +-- !query 47 schema +struct +-- !query 47 output +0 + + +-- !query 48 +select floor(-0.10) +-- !query 48 schema +struct +-- !query 48 output +-1 + + +-- !query 49 +select 1 > 0.00001 +-- !query 49 schema +struct<(CAST(1 AS BIGINT) > 0):boolean> +-- !query 49 output +true + + +-- !query 50 +select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null) +-- !query 50 schema +struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double> +-- !query 50 output +1 NULL 0 NULL NULL NULL + + +-- !query 51 +select BIT_LENGTH('abc') +-- !query 51 schema +struct +-- !query 51 output +24 + + +-- !query 52 +select CHAR_LENGTH('abc') +-- !query 52 schema +struct +-- !query 52 output +3 + + +-- !query 53 +select CHARACTER_LENGTH('abc') +-- !query 53 schema +struct +-- !query 53 output +3 + + +-- !query 54 +select OCTET_LENGTH('abc') +-- !query 54 schema +struct +-- !query 54 output +3 + + +-- !query 55 +select abs(-3.13), abs('-2.19') +-- !query 55 schema +struct +-- !query 55 output +3.13 2.19 + + +-- !query 56 +select positive('-1.11'), 
positive(-1.11), negative('-1.11'), negative(-1.11) +-- !query 56 schema +struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)> +-- !query 56 output +-1.11 -1.11 1.11 1.11 + + +-- !query 57 +select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null) +-- !query 57 schema +struct +-- !query 57 output +1 0 NULL NULL NULL NULL + + +-- !query 58 +select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint)) +-- !query 58 schema +struct +-- !query 58 output +NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out new file mode 100644 index 000000000000..8e7e04c8e1c4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -0,0 +1,218 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 27 + + +-- !query 0 +select 1 = 1 +-- !query 0 schema +struct<(1 = 1):boolean> +-- !query 0 output +true + + +-- !query 1 +select 1 = '1' +-- !query 1 schema +struct<(1 = CAST(1 AS INT)):boolean> +-- !query 1 output +true + + +-- !query 2 +select 1.0 = '1' +-- !query 2 schema +struct<(1.0 = CAST(1 AS DECIMAL(2,1))):boolean> +-- !query 2 output +true + + +-- !query 3 +select 1 > '1' +-- !query 3 schema +struct<(1 > CAST(1 AS INT)):boolean> +-- !query 3 output +false + + +-- !query 4 +select 2 > '1.0' +-- !query 4 schema +struct<(2 > CAST(1.0 AS INT)):boolean> +-- !query 4 output +true + + +-- !query 5 +select 2 > '2.0' +-- !query 5 schema +struct<(2 > CAST(2.0 AS INT)):boolean> +-- !query 5 output +false + + +-- !query 6 +select 2 > '2.2' +-- !query 6 schema +struct<(2 > CAST(2.2 AS INT)):boolean> +-- !query 6 output +false + + +-- !query 7 +select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52') +-- !query 7 schema +struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean> +-- !query 7 output +false + + +-- !query 8 +select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52' +-- !query 8 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean> +-- !query 8 output +false + + +-- !query 9 +select 1 >= '1' +-- !query 9 schema +struct<(1 >= CAST(1 AS INT)):boolean> +-- !query 9 output +true + + +-- !query 10 +select 2 >= '1.0' +-- !query 10 schema +struct<(2 >= CAST(1.0 AS INT)):boolean> +-- !query 10 output +true + + +-- !query 11 +select 2 >= '2.0' +-- !query 11 schema +struct<(2 >= CAST(2.0 AS INT)):boolean> +-- !query 11 output +true + + +-- !query 12 +select 2.0 >= '2.2' +-- !query 12 schema +struct<(2.0 >= CAST(2.2 AS DECIMAL(2,1))):boolean> +-- !query 12 output +false + + +-- !query 13 +select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52') +-- !query 13 schema +struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean> +-- !query 13 output +true + + +-- !query 14 +select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52' +-- !query 14 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean> +-- !query 14 output +false + + +-- !query 15 +select 1 < '1' +-- !query 15 schema +struct<(1 < CAST(1 AS INT)):boolean> +-- !query 15 output +false + + +-- !query 16 +select 2 < '1.0' +-- !query 16 schema +struct<(2 < CAST(1.0 AS INT)):boolean> +-- !query 16 output +false + + +-- !query 17 +select 2 < '2.0' +-- !query 17 schema +struct<(2 < CAST(2.0 
AS INT)):boolean> +-- !query 17 output +false + + +-- !query 18 +select 2.0 < '2.2' +-- !query 18 schema +struct<(2.0 < CAST(2.2 AS DECIMAL(2,1))):boolean> +-- !query 18 output +true + + +-- !query 19 +select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52') +-- !query 19 schema +struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean> +-- !query 19 output +false + + +-- !query 20 +select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52' +-- !query 20 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean> +-- !query 20 output +true + + +-- !query 21 +select 1 <= '1' +-- !query 21 schema +struct<(1 <= CAST(1 AS INT)):boolean> +-- !query 21 output +true + + +-- !query 22 +select 2 <= '1.0' +-- !query 22 schema +struct<(2 <= CAST(1.0 AS INT)):boolean> +-- !query 22 output +false + + +-- !query 23 +select 2 <= '2.0' +-- !query 23 schema +struct<(2 <= CAST(2.0 AS INT)):boolean> +-- !query 23 output +true + + +-- !query 24 +select 2.0 <= '2.2' +-- !query 24 schema +struct<(2.0 <= CAST(2.2 AS DECIMAL(2,1))):boolean> +-- !query 24 output +true + + +-- !query 25 +select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52') +-- !query 25 schema +struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean> +-- !query 25 output +true + + +-- !query 26 +select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' +-- !query 26 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> +-- !query 26 output +true diff --git a/sql/core/src/test/resources/sql-tests/results/query_regex_column.sql.out b/sql/core/src/test/resources/sql-tests/results/query_regex_column.sql.out new file mode 100644 index 000000000000..2dade86f35df --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/query_regex_column.sql.out @@ -0,0 +1,313 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 34 + + +-- !query 0 +set spark.sql.parser.quotedRegexColumnNames=false +-- !query 0 schema +struct +-- !query 0 output +spark.sql.parser.quotedRegexColumnNames false + + +-- !query 1 +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, "1", "11"), (2, "2", "22"), (3, "3", "33"), (4, "4", "44"), (5, "5", "55"), (6, "6", "66") +AS testData(key, value1, value2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE OR REPLACE TEMPORARY VIEW testData2 AS SELECT * FROM VALUES +(1, 1, 1, 2), (1, 2, 1, 2), (2, 1, 2, 3), (2, 2, 2, 3), (3, 1, 3, 4), (3, 2, 3, 4) +AS testData2(A, B, c, d) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT `(a)?+.+` FROM testData2 WHERE a = 1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a)?+.+`' given input columns: [testdata2.A, testdata2.B, testdata2.c, testdata2.d]; line 1 pos 7 + + +-- !query 4 +SELECT t.`(a)?+.+` FROM testData2 t WHERE a = 1 +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`(a)?+.+`' given input columns: [t.A, t.B, t.c, t.d]; line 1 pos 7 + + +-- !query 5 +SELECT `(a|b)` FROM testData2 WHERE a = 2 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a|b)`' given input columns: [testdata2.A, testdata2.B, testdata2.c, testdata2.d]; line 1 pos 7 + + +-- !query 6 +SELECT `(a|b)?+.+` FROM testData2 WHERE a = 2 +-- !query 6 schema +struct<> +-- !query 6 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '`(a|b)?+.+`' given input columns: [testdata2.A, testdata2.B, testdata2.c, testdata2.d]; line 1 pos 7 + + +-- !query 7 +SELECT SUM(`(a|b)?+.+`) FROM testData2 +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a|b)?+.+`' given input columns: [testdata2.A, testdata2.B, testdata2.c, testdata2.d]; line 1 pos 11 + + +-- !query 8 +SELECT SUM(`(a)`) FROM testData2 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a)`' given input columns: [testdata2.A, testdata2.B, testdata2.c, testdata2.d]; line 1 pos 11 + + +-- !query 9 +set spark.sql.parser.quotedRegexColumnNames=true +-- !query 9 schema +struct +-- !query 9 output +spark.sql.parser.quotedRegexColumnNames true + + +-- !query 10 +SELECT `(a)?+.+` FROM testData2 WHERE a = 1 +-- !query 10 schema +struct +-- !query 10 output +1 1 2 +2 1 2 + + +-- !query 11 +SELECT `(A)?+.+` FROM testData2 WHERE a = 1 +-- !query 11 schema +struct +-- !query 11 output +1 1 2 +2 1 2 + + +-- !query 12 +SELECT t.`(a)?+.+` FROM testData2 t WHERE a = 1 +-- !query 12 schema +struct +-- !query 12 output +1 1 2 +2 1 2 + + +-- !query 13 +SELECT t.`(A)?+.+` FROM testData2 t WHERE a = 1 +-- !query 13 schema +struct +-- !query 13 output +1 1 2 +2 1 2 + + +-- !query 14 +SELECT `(a|B)` FROM testData2 WHERE a = 2 +-- !query 14 schema +struct +-- !query 14 output +2 1 +2 2 + + +-- !query 15 +SELECT `(A|b)` FROM testData2 WHERE a = 2 +-- !query 15 schema +struct +-- !query 15 output +2 1 +2 2 + + +-- !query 16 +SELECT `(a|B)?+.+` FROM testData2 WHERE a = 2 +-- !query 16 schema +struct +-- !query 16 output +2 3 +2 3 + + +-- !query 17 +SELECT `(A|b)?+.+` FROM testData2 WHERE a = 2 +-- !query 17 schema +struct +-- !query 17 output +2 3 +2 3 + + +-- !query 18 +SELECT `(e|f)` FROM testData2 +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +SELECT t.`(e|f)` FROM testData2 t +-- !query 19 schema +struct<> +-- !query 19 output + + + +-- !query 20 +SELECT p.`(KEY)?+.+`, b, testdata2.`(b)?+.+` FROM testData p join testData2 ON p.key = testData2.a WHERE key < 3 +-- !query 20 schema +struct +-- !query 20 output +1 11 1 1 1 2 +1 11 2 1 1 2 +2 22 1 2 2 3 +2 22 2 2 2 3 + + +-- !query 21 +SELECT p.`(key)?+.+`, b, testdata2.`(b)?+.+` FROM testData p join testData2 ON p.key = testData2.a WHERE key < 3 +-- !query 21 schema +struct +-- !query 21 output +1 11 1 1 1 2 +1 11 2 1 1 2 +2 22 1 2 2 3 +2 22 2 2 2 3 + + +-- !query 22 +set spark.sql.caseSensitive=true +-- !query 22 schema +struct +-- !query 22 output +spark.sql.caseSensitive true + + +-- !query 23 +CREATE OR REPLACE TEMPORARY VIEW testdata3 AS SELECT * FROM VALUES +(0, 1), (1, 2), (2, 3), (3, 4) +AS testdata3(a, b) +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +SELECT `(A)?+.+` FROM testdata3 +-- !query 24 schema +struct +-- !query 24 output +0 1 +1 2 +2 3 +3 4 + + +-- !query 25 +SELECT `(a)?+.+` FROM testdata3 +-- !query 25 schema +struct +-- !query 25 output +1 +2 +3 +4 + + +-- !query 26 +SELECT `(A)?+.+` FROM testdata3 WHERE a > 1 +-- !query 26 schema +struct +-- !query 26 output +2 3 +3 4 + + +-- !query 27 +SELECT `(a)?+.+` FROM testdata3 where `a` > 1 +-- !query 27 schema +struct +-- !query 27 output +3 +4 + + +-- !query 28 +SELECT SUM(`a`) FROM testdata3 +-- !query 28 schema +struct +-- !query 28 output +6 + + +-- !query 29 +SELECT SUM(`(a)`) FROM testdata3 +-- !query 29 schema +struct +-- !query 29 output +6 + + +-- 
!query 30 +SELECT SUM(`(a)?+.+`) FROM testdata3 +-- !query 30 schema +struct +-- !query 30 output +10 + + +-- !query 31 +SELECT SUM(a) FROM testdata3 GROUP BY `a` +-- !query 31 schema +struct +-- !query 31 output +0 +1 +2 +3 + + +-- !query 32 +SELECT SUM(a) FROM testdata3 GROUP BY `(a)` +-- !query 32 schema +struct<> +-- !query 32 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a)`' given input columns: [testdata3.a, testdata3.b]; line 1 pos 38 + + +-- !query 33 +SELECT SUM(a) FROM testdata3 GROUP BY `(a)?+.+` +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a)?+.+`' given input columns: [testdata3.a, testdata3.b]; line 1 pos 38 diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 8f2a54f7c24e..975bb0612474 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -119,8 +119,9 @@ SHOW TABLE EXTENDED LIKE 'show_t*' struct -- !query 12 output show_t3 true Table: show_t3 -Created [not included in comparison] +Created Time [not included in comparison] Last Access [not included in comparison] +Created By [not included in comparison] Type: VIEW Schema: root |-- e: integer (nullable = true) @@ -128,8 +129,9 @@ Schema: root showdb show_t1 false Database: showdb Table: show_t1 -Created [not included in comparison] +Created Time [not included in comparison] Last Access [not included in comparison] +Created By [not included in comparison] Type: MANAGED Provider: parquet Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1 @@ -144,8 +146,9 @@ Schema: root showdb show_t2 false Database: showdb Table: show_t2 -Created [not included in comparison] +Created Time [not included in comparison] Last Access [not included in comparison] +Created By [not included in comparison] Type: MANAGED Provider: parquet Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t2 @@ -161,7 +164,7 @@ struct<> -- !query 13 output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input '' expecting 'LIKE'(line 1, pos 19) +mismatched input '' expecting {'FROM', 'IN', 'LIKE'}(line 1, pos 19) == SQL == SHOW TABLE EXTENDED @@ -184,7 +187,7 @@ struct<> -- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'PARTITION' expecting 'LIKE'(line 1, pos 20) +mismatched input 'PARTITION' expecting {'FROM', 'IN', 'LIKE'}(line 1, pos 20) == SQL == SHOW TABLE EXTENDED PARTITION(c='Us', d=1) diff --git a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out index 05c3a083ee3b..71d6e120e894 100644 --- a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out @@ -19,7 +19,7 @@ struct<> -- !query 2 -CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING parquet +CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING json -- !query 2 schema struct<> -- !query 2 output @@ -35,7 +35,7 @@ struct<> -- !query 4 -CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet +CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING json -- !query 4 schema struct<> -- !query 4 output diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out 
b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 732b11050f46..e035505f15d2 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 15 -- !query 0 @@ -122,3 +122,19 @@ struct<> -- !query 12 output org.apache.spark.sql.AnalysisException Function string accepts only one argument; line 1 pos 7 + + +-- !query 13 +CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st) +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value") +-- !query 14 schema +struct +-- !query 14 output +gamma 1 diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 6961e9b65922..2d9b3d7d2ca3 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 2 +-- Number of queries: 12 -- !query 0 @@ -18,3 +18,103 @@ struct<> -- !query 1 output org.apache.spark.sql.AnalysisException requirement failed: format_string() should take at least 1 argument; line 1 pos 7 + + +-- !query 2 +select 'a' || 'b' || 'c' +-- !query 2 schema +struct +-- !query 2 output +abc + + +-- !query 3 +EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col +FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) +-- !query 3 schema +struct +-- !query 3 output +== Parsed Logical Plan == +'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x] ++- 'SubqueryAlias __auto_generated_subquery_name + +- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x] + +- 'UnresolvedTableValuedFunction range, [10] + +== Analyzed Logical Plan == +col: string +Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] ++- SubqueryAlias __auto_generated_subquery_name + +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] + +- Range (0, 10, step=1, splits=None) + +== Optimized Logical Plan == +Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x] ++- Range (0, 10, step=1, splits=None) + +== Physical Plan == +*Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x] ++- *Range (0, 10, step=1, splits=2) + + +-- !query 4 +select replace('abc', 'b', '123') +-- !query 4 schema +struct +-- !query 4 output +a123c + + +-- !query 5 +select replace('abc', 'b') +-- !query 5 schema +struct +-- !query 5 output +ac + + +-- !query 6 +select length(uuid()), (uuid() <> uuid()) +-- !query 6 schema +struct +-- !query 6 output +36 true + + +-- !query 7 +select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null) +-- !query 7 schema +struct +-- !query 7 output +4 NULL NULL + + +-- !query 8 +select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null) +-- !query 8 schema +struct +-- !query 8 output +ab abcd ab NULL + + +-- !query 9 +select left(null, -2), 
left("abcd", -2), left("abcd", 0), left("abcd", 'a') +-- !query 9 schema +struct +-- !query 9 output +NULL NULL + + +-- !query 10 +select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null) +-- !query 10 schema +struct +-- !query 10 output +cd abcd cd NULL + + +-- !query 11 +select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') +-- !query 11 schema +struct +-- !query 11 output +NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out index 3e32f4619546..1da33bc736f0 100644 --- a/sql/core/src/test/resources/sql-tests/results/struct.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 9 -- !query 0 @@ -58,3 +58,33 @@ struct> 1 {"AA":"1","C":"gamma","D":"delta"} 2 {"AA":"2","C":"epsilon","D":"eta"} 3 {"AA":"3","C":"theta","D":"iota"} + + +-- !query 6 +SELECT ID, STRUCT(ST.*).C NST FROM tbl_x +-- !query 6 schema +struct +-- !query 6 output +1 gamma +2 epsilon +3 theta + + +-- !query 7 +SELECT ID, STRUCT(ST.C, ST.D).D NST FROM tbl_x +-- !query 7 schema +struct +-- !query 7 output +1 delta +2 eta +3 iota + + +-- !query 8 +SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x +-- !query 8 schema +struct +-- !query 8 output +1 delta +2 eta +3 iota diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out new file mode 100644 index 000000000000..70aeb9373f3c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -0,0 +1,104 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (1, 2, 3) +AS t1(t1a, t1b, t1c) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES + (1, 0, 1) +AS t2(t2a, t2b, t2c) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES + (3, 1, 2) +AS t3(t3a, t3b, t3c) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b = t1.t1b + GROUP BY t2.t2b + ) +FROM t1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +Scalar subquery must return only one column, but got 2; + + +-- !query 4 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b > 0 + GROUP BY t2.t2b + ) +FROM t1 +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Scalar subquery must return only one column, but got 2; + + +-- !query 5 +SELECT * FROM t1 +WHERE +t1a IN (SELECT t2a, t2b + FROM t2 + WHERE t1a = t2a) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 1. +#columns in right hand side: 2. +Left side columns: +[t1.`t1a`]. 
+Right side columns: +[t2.`t2a`, t2.`t2b`].; + + +-- !query 6 +SELECT * FROM T1 +WHERE +(t1a, t1b) IN (SELECT t2a + FROM t2 + WHERE t1a = t2a) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 2. +#columns in right hand side: 1. +Left side columns: +[t1.`t1a`, t1.`t1b`]. +Right side columns: +[t2.`t2a`].; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out new file mode 100644 index 000000000000..50370df34916 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out @@ -0,0 +1,50 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT * FROM (SELECT * FROM testData) AS t WHERE key = 1 +-- !query 0 schema +struct +-- !query 0 output +1 1 + + +-- !query 1 +FROM (SELECT * FROM testData WHERE key = 1) AS t SELECT * +-- !query 1 schema +struct +-- !query 1 output +1 1 + + +-- !query 2 +SELECT * FROM (SELECT * FROM testData) t WHERE key = 1 +-- !query 2 schema +struct +-- !query 2 output +1 1 + + +-- !query 3 +FROM (SELECT * FROM testData WHERE key = 1) t SELECT * +-- !query 3 schema +struct +-- !query 3 output +1 1 + + +-- !query 4 +SELECT * FROM (SELECT * FROM testData) WHERE key = 1 +-- !query 4 schema +struct +-- !query 4 output +1 1 + + +-- !query 5 +FROM (SELECT * FROM testData WHERE key = 1) SELECT * +-- !query 5 schema +struct +-- !query 5 output +1 1 diff --git a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out new file mode 100644 index 000000000000..1a2bd5ea91cd --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out @@ -0,0 +1,97 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES (1, 1), (1, 2), (2, 1) AS testData(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT * FROM testData AS t(col1, col2) WHERE col1 = 1 +-- !query 1 schema +struct +-- !query 1 output +1 1 +1 2 + + +-- !query 2 +SELECT * FROM testData AS t(col1, col2) WHERE col1 = 2 +-- !query 2 schema +struct +-- !query 2 output +2 1 + + +-- !query 3 +SELECT col1 AS k, SUM(col2) FROM testData AS t(col1, col2) GROUP BY k +-- !query 3 schema +struct +-- !query 3 output +1 3 +2 1 + + +-- !query 4 +SELECT * FROM testData AS t(col1, col2, col3) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Number of column aliases does not match number of columns. Number of column aliases: 3; number of columns: 2.; line 1 pos 14 + + +-- !query 5 +SELECT * FROM testData AS t(col1) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Number of column aliases does not match number of columns. 
Number of column aliases: 1; number of columns: 2.; line 1 pos 14 + + +-- !query 6 +SELECT a AS col1, b AS col2 FROM testData AS t(c, d) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '`a`' given input columns: [c, d]; line 1 pos 7 + + +-- !query 7 +SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) +-- !query 7 schema +struct +-- !query 7 output +1 1 + + +-- !query 8 +CREATE OR REPLACE TEMPORARY VIEW src1 AS SELECT * FROM VALUES (1, "a"), (2, "b"), (3, "c") AS src1(id, v1) +-- !query 8 schema +struct<> +-- !query 8 output + + + +-- !query 9 +CREATE OR REPLACE TEMPORARY VIEW src2 AS SELECT * FROM VALUES (2, 1.0), (3, 3.2), (1, 8.5) AS src2(id, v2) +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +SELECT * FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d) +-- !query 10 schema +struct +-- !query 10 output +1 a 1 8.5 +2 b 2 1 +3 c 3 3.2 diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index e2ee970d35f6..a8bc6faf1126 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 10 -- !query 0 @@ -103,3 +103,33 @@ struct -- !query 8 output == Physical Plan == *Range (0, 2, step=1, splits=2) + + +-- !query 9 +EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3) +-- !query 9 schema +struct +-- !query 9 output +== Parsed Logical Plan == +'Project [*] ++- 'Join Cross + :- 'UnresolvedTableValuedFunction range, [3] + +- 'UnresolvedTableValuedFunction range, [3] + +== Analyzed Logical Plan == +id: bigint, id: bigint +Project [id#xL, id#xL] ++- Join Cross + :- Range (0, 3, step=1, splits=None) + +- Range (0, 3, step=1, splits=None) + +== Optimized Logical Plan == +Join Cross +:- Range (0, 3, step=1, splits=None) ++- Range (0, 3, step=1, splits=None) + +== Physical Plan == +BroadcastNestedLoopJoin BuildRight, Cross +:- *Range (0, 3, step=1, splits=2) ++- BroadcastExchange IdentityBroadcastMode + +- *Range (0, 3, step=1, splits=2) diff --git a/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out new file mode 100644 index 000000000000..35f3931736b8 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out @@ -0,0 +1,62 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE DATABASE mydb1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +USE mydb1 +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1 +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT) +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.catalyst.parser.ParseException + +Sampling fraction (-0.01) must be on interval [0, 1](line 1, pos 24) + +== SQL == +SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT) +------------------------^^^ + + +-- !query 4 +SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.catalyst.parser.ParseException + +Sampling fraction (1.01) must be on interval [0, 
1](line 1, pos 24) + +== SQL == +SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT) +------------------------^^^ + + +-- !query 5 +DROP DATABASE mydb1 CASCADE +-- !query 5 schema +struct<> +-- !query 5 output + diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out new file mode 100644 index 000000000000..4815a578b102 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -0,0 +1,54 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(1), (2), (3), (4) +as t1(int_col1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg' +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT default.myDoubleAvg(int_col1) as my_avg from t1 +-- !query 2 schema +struct +-- !query 2 output +102.5 + + +-- !query 3 +SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1 +-- !query 3 schema +struct<> +-- !query 3 output +java.lang.AssertionError +assertion failed: Incorrect number of children + + +-- !query 4 +CREATE FUNCTION udaf1 AS 'test.non.existent.udaf' +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +SELECT default.udaf1(int_col1) as udaf1 from t1 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out new file mode 100644 index 000000000000..73ad27e5bf8c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -0,0 +1,357 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 22 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(null, 1L, 1.0D, date("2017-08-01"), timestamp(1501545600), "a"), +(1, 1L, 1.0D, date("2017-08-01"), timestamp(1501545600), "a"), +(1, 2L, 2.5D, date("2017-08-02"), timestamp(1502000000), "a"), +(2, 2147483650L, 100.001D, date("2020-12-31"), timestamp(1609372800), "a"), +(1, null, 1.0D, date("2017-08-01"), timestamp(1501545600), "b"), +(2, 3L, 3.3D, date("2017-08-03"), timestamp(1503000000), "b"), +(3, 2147483650L, 100.001D, date("2020-12-31"), timestamp(1609372800), "b"), +(null, null, null, null, null, null), +(3, 1L, 1.0D, date("2017-08-01"), timestamp(1501545600), null) +AS testData(val, val_long, val_double, val_date, val_timestamp, cate) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) FROM testData +ORDER BY cate, val +-- !query 1 schema +struct +-- !query 1 output +NULL NULL 0 +3 NULL 1 +NULL a 0 +1 a 1 +1 a 1 +2 a 1 +1 b 1 +2 b 1 +3 b 1 + + +-- !query 2 +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val +ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 2 schema +struct +-- !query 2 output +NULL NULL 3 +3 NULL 3 +NULL a 1 +1 a 2 +1 a 4 +2 a 4 +1 b 3 +2 b 6 +3 b 6 + + +-- !query 3 +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long +-- !query 3 schema +struct<> +-- !query 3 
output +org.apache.spark.sql.AnalysisException +cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'LongType does not match the expected data type 'IntegerType'.; line 1 pos 41 + + +-- !query 4 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData +ORDER BY cate, val +-- !query 4 schema +struct +-- !query 4 output +NULL NULL 0 +3 NULL 1 +NULL a 0 +1 a 2 +1 a 2 +2 a 3 +1 b 1 +2 b 2 +3 b 2 + + +-- !query 5 +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 5 schema +struct +-- !query 5 output +NULL NULL NULL +3 NULL 3 +NULL a NULL +1 a 4 +1 a 4 +2 a 2 +1 b 3 +2 b 5 +3 b 3 + + +-- !query 6 +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +RANGE BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long +-- !query 6 schema +struct +-- !query 6 output +NULL NULL NULL +1 NULL 1 +1 a 4 +1 a 4 +2 a 2147483652 +2147483650 a 2147483650 +NULL b NULL +3 b 2147483653 +2147483650 b 2147483650 + + +-- !query 7 +SELECT val_double, cate, sum(val_double) OVER(PARTITION BY cate ORDER BY val_double +RANGE BETWEEN CURRENT ROW AND 2.5 FOLLOWING) FROM testData ORDER BY cate, val_double +-- !query 7 schema +struct +-- !query 7 output +NULL NULL NULL +1.0 NULL 1.0 +1.0 a 4.5 +1.0 a 4.5 +2.5 a 2.5 +100.001 a 100.001 +1.0 b 4.3 +3.3 b 3.3 +100.001 b 100.001 + + +-- !query 8 +SELECT val_date, cate, max(val_date) OVER(PARTITION BY cate ORDER BY val_date +RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) FROM testData ORDER BY cate, val_date +-- !query 8 schema +struct +-- !query 8 output +NULL NULL NULL +2017-08-01 NULL 2017-08-01 +2017-08-01 a 2017-08-02 +2017-08-01 a 2017-08-02 +2017-08-02 a 2017-08-02 +2020-12-31 a 2020-12-31 +2017-08-01 b 2017-08-03 +2017-08-03 b 2017-08-03 +2020-12-31 b 2020-12-31 + + +-- !query 9 +SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp +RANGE BETWEEN CURRENT ROW AND interval 23 days 4 hours FOLLOWING) FROM testData +ORDER BY cate, val_timestamp +-- !query 9 schema +struct +-- !query 9 output +NULL NULL NULL +2017-07-31 17:00:00 NULL 1.5015456E9 +2017-07-31 17:00:00 a 1.5016970666666667E9 +2017-07-31 17:00:00 a 1.5016970666666667E9 +2017-08-05 23:13:20 a 1.502E9 +2020-12-30 16:00:00 a 1.6093728E9 +2017-07-31 17:00:00 b 1.5022728E9 +2017-08-17 13:00:00 b 1.503E9 +2020-12-30 16:00:00 b 1.6093728E9 + + +-- !query 10 +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 10 schema +struct +-- !query 10 output +NULL NULL NULL +3 NULL 3 +NULL a NULL +1 a 2 +1 a 2 +2 a 4 +1 b 1 +2 b 3 +3 b 5 + + +-- !query 11 +SELECT val, cate, count(val) OVER(PARTITION BY cate +ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve 'ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING' due to data type mismatch: Window frame upper bound '1' does not followes the lower bound 'unboundedfollowing$()'.; line 1 pos 33 + + +-- !query 12 +SELECT val, cate, count(val) OVER(PARTITION BY cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve 
'(PARTITION BY testdata.`cate` RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: A range window frame cannot be used in an unordered window specification.; line 1 pos 33 + + +-- !query 13 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val, cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY testdata.`val` ASC NULLS FIRST, testdata.`cate` ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: A range window frame with value boundaries cannot be used in a window specification with multiple order by expressions: val#x ASC NULLS FIRST,cate#x ASC NULLS FIRST; line 1 pos 33 + + +-- !query 14 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY current_timestamp +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY current_timestamp() ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'TimestampType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 33 + + +-- !query 15 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING) FROM testData ORDER BY cate, val +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +cannot resolve 'RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING' due to data type mismatch: The lower bound of a window frame must be less than or equal to the upper bound; line 1 pos 33 + + +-- !query 16 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.catalyst.parser.ParseException + +Frame bound value must be a literal.(line 2, pos 30) + +== SQL == +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val +------------------------------^^^ + + +-- !query 17 +SELECT val, cate, +max(val) OVER w AS max, +min(val) OVER w AS min, +min(val) OVER w AS min, +count(val) OVER w AS count, +sum(val) OVER w AS sum, +avg(val) OVER w AS avg, +stddev(val) OVER w AS stddev, +first_value(val) OVER w AS first_value, +first_value(val, true) OVER w AS first_value_ignore_null, +first_value(val, false) OVER w AS first_value_contain_null, +last_value(val) OVER w AS last_value, +last_value(val, true) OVER w AS last_value_ignore_null, +last_value(val, false) OVER w AS last_value_contain_null, +rank() OVER w AS rank, +dense_rank() OVER w AS dense_rank, +cume_dist() OVER w AS cume_dist, +percent_rank() OVER w AS percent_rank, +ntile(2) OVER w AS ntile, +row_number() OVER w AS row_number, +var_pop(val) OVER w AS var_pop, +var_samp(val) OVER w AS var_samp, +approx_count_distinct(val) OVER w AS approx_count_distinct +FROM testData +WINDOW w AS (PARTITION BY cate ORDER BY val) +ORDER BY cate, val +-- !query 17 schema +struct +-- !query 17 output +NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 +3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 +NULL a NULL 
NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 +2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 +1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 +2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 +3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 + + +-- !query 18 +SELECT val, cate, avg(null) OVER(PARTITION BY cate ORDER BY val) FROM testData ORDER BY cate, val +-- !query 18 schema +struct +-- !query 18 output +NULL NULL NULL +3 NULL NULL +NULL a NULL +1 a NULL +1 a NULL +2 a NULL +1 b NULL +2 b NULL +3 b NULL + + +-- !query 19 +SELECT val, cate, row_number() OVER(PARTITION BY cate) FROM testData ORDER BY cate, val +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +Window function row_number() requires window to be ordered, please add ORDER BY clause. For example SELECT row_number()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table; + + +-- !query 20 +SELECT val, cate, sum(val) OVER(), avg(val) OVER() FROM testData ORDER BY cate, val +-- !query 20 schema +struct +-- !query 20 output +NULL NULL 13 1.8571428571428572 +3 NULL 13 1.8571428571428572 +NULL a 13 1.8571428571428572 +1 a 13 1.8571428571428572 +1 a 13 1.8571428571428572 +2 a 13 1.8571428571428572 +1 b 13 1.8571428571428572 +2 b 13 1.8571428571428572 +3 b 13 1.8571428571428572 + + +-- !query 21 +SELECT val, cate, +first_value(false) OVER w AS first_value, +first_value(true, true) OVER w AS first_value_ignore_null, +first_value(false, false) OVER w AS first_value_contain_null, +last_value(false) OVER w AS last_value, +last_value(true, true) OVER w AS last_value_ignore_null, +last_value(false, false) OVER w AS last_value_contain_null +FROM testData +WINDOW w AS () +ORDER BY cate, val +-- !query 21 schema +struct +-- !query 21 output +NULL NULL false true false false true false +3 NULL false true false false true false +NULL a false true false false true false +1 a false true false false true false +1 a false true false false true false +2 a false true false false true false +1 b false true false false true false +2 b false true false false true false +3 b false true false false true false diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql new file mode 100755 index 000000000000..79dd3d516e8c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql @@ -0,0 +1,70 @@ +-- start query 10 in stream 0 using template query10.tpl +with +v1 as ( + select + ws_bill_customer_sk as customer_sk + from web_sales, + date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 + union all + select + cs_ship_customer_sk as customer_sk + from catalog_sales, + date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 +), +v2 as ( + select + ss_customer_sk as customer_sk + from store_sales, + date_dim + where ss_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 +) +select + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) 
cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +from customer c +join customer_address ca on (c.c_current_addr_sk = ca.ca_address_sk) +join customer_demographics on (cd_demo_sk = c.c_current_cdemo_sk) +left semi join v1 on (v1.customer_sk = c.c_customer_sk) +left semi join v2 on (v2.customer_sk = c.c_customer_sk) +where + ca_county in ('Walker County','Richland County','Gaines County','Douglas County','Dona Ana County') +group by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 +-- end query 10 in stream 0 using template query10.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql new file mode 100755 index 000000000000..179982776291 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql @@ -0,0 +1,38 @@ +-- start query 19 in stream 0 using template query19.tpl +select + i_brand_id brand_id, + i_brand brand, + i_manufact_id, + i_manufact, + sum(ss_ext_sales_price) ext_price +from + date_dim, + store_sales, + item, + customer, + customer_address, + store +where + d_date_sk = ss_sold_date_sk + and ss_item_sk = i_item_sk + and i_manager_id = 7 + and d_moy = 11 + and d_year = 1999 + and ss_customer_sk = c_customer_sk + and c_current_addr_sk = ca_address_sk + and substr(ca_zip, 1, 5) <> substr(s_zip, 1, 5) + and ss_store_sk = s_store_sk + and ss_sold_date_sk between 2451484 and 2451513 -- partition key filter +group by + i_brand, + i_brand_id, + i_manufact_id, + i_manufact +order by + ext_price desc, + i_brand, + i_brand_id, + i_manufact_id, + i_manufact +limit 100 +-- end query 19 in stream 0 using template query19.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql new file mode 100755 index 000000000000..dedbc62a2ab2 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql @@ -0,0 +1,43 @@ +-- start query 27 in stream 0 using template query27.tpl + with results as + (select i_item_id, + s_state, + ss_quantity agg1, + ss_list_price agg2, + ss_coupon_amt agg3, + ss_sales_price agg4 + --0 as g_state, + --avg(ss_quantity) agg1, + --avg(ss_list_price) agg2, + --avg(ss_coupon_amt) agg3, + --avg(ss_sales_price) agg4 + from store_sales, customer_demographics, date_dim, store, item + where ss_sold_date_sk = d_date_sk and + ss_sold_date_sk between 2451545 and 2451910 and + ss_item_sk = i_item_sk and + ss_store_sk = s_store_sk and + ss_cdemo_sk = cd_demo_sk and + cd_gender = 'F' and + cd_marital_status = 'D' and + cd_education_status = 'Primary' and + d_year = 2000 and + s_state in ('TN','AL', 'SD', 'SD', 'SD', 'SD') + --group by i_item_id, s_state + ) + + select i_item_id, + s_state, g_state, agg1, agg2, agg3, agg4 + from ( + select i_item_id, s_state, 0 as g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, avg(agg4) agg4 from results + group by i_item_id, s_state + union all + select i_item_id, NULL AS s_state, 1 AS g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, + avg(agg4) agg4 from results + group by i_item_id + union all + select NULL AS i_item_id, NULL as s_state, 1 as g_state, avg(agg1) agg1, avg(agg2) 
agg2, avg(agg3) agg3, + avg(agg4) agg4 from results + ) foo + order by i_item_id, s_state + limit 100 +-- end query 27 in stream 0 using template query27.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql new file mode 100755 index 000000000000..35b0a20f80a4 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql @@ -0,0 +1,228 @@ +-- start query 3 in stream 0 using template query3.tpl +select + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + sum(ss_net_profit) sum_agg +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manufact_id = 436 + and dt.d_moy = 12 + -- partition key filters + and ( +ss_sold_date_sk between 2415355 and 2415385 +or ss_sold_date_sk between 2415720 and 2415750 +or ss_sold_date_sk between 2416085 and 2416115 +or ss_sold_date_sk between 2416450 and 2416480 +or ss_sold_date_sk between 2416816 and 2416846 +or ss_sold_date_sk between 2417181 and 2417211 +or ss_sold_date_sk between 2417546 and 2417576 +or ss_sold_date_sk between 2417911 and 2417941 +or ss_sold_date_sk between 2418277 and 2418307 +or ss_sold_date_sk between 2418642 and 2418672 +or ss_sold_date_sk between 2419007 and 2419037 +or ss_sold_date_sk between 2419372 and 2419402 +or ss_sold_date_sk between 2419738 and 2419768 +or ss_sold_date_sk between 2420103 and 2420133 +or ss_sold_date_sk between 2420468 and 2420498 +or ss_sold_date_sk between 2420833 and 2420863 +or ss_sold_date_sk between 2421199 and 2421229 +or ss_sold_date_sk between 2421564 and 2421594 +or ss_sold_date_sk between 2421929 and 2421959 +or ss_sold_date_sk between 2422294 and 2422324 +or ss_sold_date_sk between 2422660 and 2422690 +or ss_sold_date_sk between 2423025 and 2423055 +or ss_sold_date_sk between 2423390 and 2423420 +or ss_sold_date_sk between 2423755 and 2423785 +or ss_sold_date_sk between 2424121 and 2424151 +or ss_sold_date_sk between 2424486 and 2424516 +or ss_sold_date_sk between 2424851 and 2424881 +or ss_sold_date_sk between 2425216 and 2425246 +or ss_sold_date_sk between 2425582 and 2425612 +or ss_sold_date_sk between 2425947 and 2425977 +or ss_sold_date_sk between 2426312 and 2426342 +or ss_sold_date_sk between 2426677 and 2426707 +or ss_sold_date_sk between 2427043 and 2427073 +or ss_sold_date_sk between 2427408 and 2427438 +or ss_sold_date_sk between 2427773 and 2427803 +or ss_sold_date_sk between 2428138 and 2428168 +or ss_sold_date_sk between 2428504 and 2428534 +or ss_sold_date_sk between 2428869 and 2428899 +or ss_sold_date_sk between 2429234 and 2429264 +or ss_sold_date_sk between 2429599 and 2429629 +or ss_sold_date_sk between 2429965 and 2429995 +or ss_sold_date_sk between 2430330 and 2430360 +or ss_sold_date_sk between 2430695 and 2430725 +or ss_sold_date_sk between 2431060 and 2431090 +or ss_sold_date_sk between 2431426 and 2431456 +or ss_sold_date_sk between 2431791 and 2431821 +or ss_sold_date_sk between 2432156 and 2432186 +or ss_sold_date_sk between 2432521 and 2432551 +or ss_sold_date_sk between 2432887 and 2432917 +or ss_sold_date_sk between 2433252 and 2433282 +or ss_sold_date_sk between 2433617 and 2433647 +or ss_sold_date_sk between 2433982 and 2434012 +or ss_sold_date_sk between 2434348 and 2434378 +or ss_sold_date_sk between 2434713 and 2434743 +or ss_sold_date_sk between 2435078 and 2435108 +or ss_sold_date_sk between 2435443 and 2435473 +or ss_sold_date_sk between 2435809 and 2435839 
+or ss_sold_date_sk between 2436174 and 2436204 +or ss_sold_date_sk between 2436539 and 2436569 +or ss_sold_date_sk between 2436904 and 2436934 +or ss_sold_date_sk between 2437270 and 2437300 +or ss_sold_date_sk between 2437635 and 2437665 +or ss_sold_date_sk between 2438000 and 2438030 +or ss_sold_date_sk between 2438365 and 2438395 +or ss_sold_date_sk between 2438731 and 2438761 +or ss_sold_date_sk between 2439096 and 2439126 +or ss_sold_date_sk between 2439461 and 2439491 +or ss_sold_date_sk between 2439826 and 2439856 +or ss_sold_date_sk between 2440192 and 2440222 +or ss_sold_date_sk between 2440557 and 2440587 +or ss_sold_date_sk between 2440922 and 2440952 +or ss_sold_date_sk between 2441287 and 2441317 +or ss_sold_date_sk between 2441653 and 2441683 +or ss_sold_date_sk between 2442018 and 2442048 +or ss_sold_date_sk between 2442383 and 2442413 +or ss_sold_date_sk between 2442748 and 2442778 +or ss_sold_date_sk between 2443114 and 2443144 +or ss_sold_date_sk between 2443479 and 2443509 +or ss_sold_date_sk between 2443844 and 2443874 +or ss_sold_date_sk between 2444209 and 2444239 +or ss_sold_date_sk between 2444575 and 2444605 +or ss_sold_date_sk between 2444940 and 2444970 +or ss_sold_date_sk between 2445305 and 2445335 +or ss_sold_date_sk between 2445670 and 2445700 +or ss_sold_date_sk between 2446036 and 2446066 +or ss_sold_date_sk between 2446401 and 2446431 +or ss_sold_date_sk between 2446766 and 2446796 +or ss_sold_date_sk between 2447131 and 2447161 +or ss_sold_date_sk between 2447497 and 2447527 +or ss_sold_date_sk between 2447862 and 2447892 +or ss_sold_date_sk between 2448227 and 2448257 +or ss_sold_date_sk between 2448592 and 2448622 +or ss_sold_date_sk between 2448958 and 2448988 +or ss_sold_date_sk between 2449323 and 2449353 +or ss_sold_date_sk between 2449688 and 2449718 +or ss_sold_date_sk between 2450053 and 2450083 +or ss_sold_date_sk between 2450419 and 2450449 +or ss_sold_date_sk between 2450784 and 2450814 +or ss_sold_date_sk between 2451149 and 2451179 +or ss_sold_date_sk between 2451514 and 2451544 +or ss_sold_date_sk between 2451880 and 2451910 +or ss_sold_date_sk between 2452245 and 2452275 +or ss_sold_date_sk between 2452610 and 2452640 +or ss_sold_date_sk between 2452975 and 2453005 +or ss_sold_date_sk between 2453341 and 2453371 +or ss_sold_date_sk between 2453706 and 2453736 +or ss_sold_date_sk between 2454071 and 2454101 +or ss_sold_date_sk between 2454436 and 2454466 +or ss_sold_date_sk between 2454802 and 2454832 +or ss_sold_date_sk between 2455167 and 2455197 +or ss_sold_date_sk between 2455532 and 2455562 +or ss_sold_date_sk between 2455897 and 2455927 +or ss_sold_date_sk between 2456263 and 2456293 +or ss_sold_date_sk between 2456628 and 2456658 +or ss_sold_date_sk between 2456993 and 2457023 +or ss_sold_date_sk between 2457358 and 2457388 +or ss_sold_date_sk between 2457724 and 2457754 +or ss_sold_date_sk between 2458089 and 2458119 +or ss_sold_date_sk between 2458454 and 2458484 +or ss_sold_date_sk between 2458819 and 2458849 +or ss_sold_date_sk between 2459185 and 2459215 +or ss_sold_date_sk between 2459550 and 2459580 +or ss_sold_date_sk between 2459915 and 2459945 +or ss_sold_date_sk between 2460280 and 2460310 +or ss_sold_date_sk between 2460646 and 2460676 +or ss_sold_date_sk between 2461011 and 2461041 +or ss_sold_date_sk between 2461376 and 2461406 +or ss_sold_date_sk between 2461741 and 2461771 +or ss_sold_date_sk between 2462107 and 2462137 +or ss_sold_date_sk between 2462472 and 2462502 +or ss_sold_date_sk between 2462837 and 2462867 
+or ss_sold_date_sk between 2463202 and 2463232 +or ss_sold_date_sk between 2463568 and 2463598 +or ss_sold_date_sk between 2463933 and 2463963 +or ss_sold_date_sk between 2464298 and 2464328 +or ss_sold_date_sk between 2464663 and 2464693 +or ss_sold_date_sk between 2465029 and 2465059 +or ss_sold_date_sk between 2465394 and 2465424 +or ss_sold_date_sk between 2465759 and 2465789 +or ss_sold_date_sk between 2466124 and 2466154 +or ss_sold_date_sk between 2466490 and 2466520 +or ss_sold_date_sk between 2466855 and 2466885 +or ss_sold_date_sk between 2467220 and 2467250 +or ss_sold_date_sk between 2467585 and 2467615 +or ss_sold_date_sk between 2467951 and 2467981 +or ss_sold_date_sk between 2468316 and 2468346 +or ss_sold_date_sk between 2468681 and 2468711 +or ss_sold_date_sk between 2469046 and 2469076 +or ss_sold_date_sk between 2469412 and 2469442 +or ss_sold_date_sk between 2469777 and 2469807 +or ss_sold_date_sk between 2470142 and 2470172 +or ss_sold_date_sk between 2470507 and 2470537 +or ss_sold_date_sk between 2470873 and 2470903 +or ss_sold_date_sk between 2471238 and 2471268 +or ss_sold_date_sk between 2471603 and 2471633 +or ss_sold_date_sk between 2471968 and 2471998 +or ss_sold_date_sk between 2472334 and 2472364 +or ss_sold_date_sk between 2472699 and 2472729 +or ss_sold_date_sk between 2473064 and 2473094 +or ss_sold_date_sk between 2473429 and 2473459 +or ss_sold_date_sk between 2473795 and 2473825 +or ss_sold_date_sk between 2474160 and 2474190 +or ss_sold_date_sk between 2474525 and 2474555 +or ss_sold_date_sk between 2474890 and 2474920 +or ss_sold_date_sk between 2475256 and 2475286 +or ss_sold_date_sk between 2475621 and 2475651 +or ss_sold_date_sk between 2475986 and 2476016 +or ss_sold_date_sk between 2476351 and 2476381 +or ss_sold_date_sk between 2476717 and 2476747 +or ss_sold_date_sk between 2477082 and 2477112 +or ss_sold_date_sk between 2477447 and 2477477 +or ss_sold_date_sk between 2477812 and 2477842 +or ss_sold_date_sk between 2478178 and 2478208 +or ss_sold_date_sk between 2478543 and 2478573 +or ss_sold_date_sk between 2478908 and 2478938 +or ss_sold_date_sk between 2479273 and 2479303 +or ss_sold_date_sk between 2479639 and 2479669 +or ss_sold_date_sk between 2480004 and 2480034 +or ss_sold_date_sk between 2480369 and 2480399 +or ss_sold_date_sk between 2480734 and 2480764 +or ss_sold_date_sk between 2481100 and 2481130 +or ss_sold_date_sk between 2481465 and 2481495 +or ss_sold_date_sk between 2481830 and 2481860 +or ss_sold_date_sk between 2482195 and 2482225 +or ss_sold_date_sk between 2482561 and 2482591 +or ss_sold_date_sk between 2482926 and 2482956 +or ss_sold_date_sk between 2483291 and 2483321 +or ss_sold_date_sk between 2483656 and 2483686 +or ss_sold_date_sk between 2484022 and 2484052 +or ss_sold_date_sk between 2484387 and 2484417 +or ss_sold_date_sk between 2484752 and 2484782 +or ss_sold_date_sk between 2485117 and 2485147 +or ss_sold_date_sk between 2485483 and 2485513 +or ss_sold_date_sk between 2485848 and 2485878 +or ss_sold_date_sk between 2486213 and 2486243 +or ss_sold_date_sk between 2486578 and 2486608 +or ss_sold_date_sk between 2486944 and 2486974 +or ss_sold_date_sk between 2487309 and 2487339 +or ss_sold_date_sk between 2487674 and 2487704 +or ss_sold_date_sk between 2488039 and 2488069 +) +group by + dt.d_year, + item.i_brand, + item.i_brand_id +order by + dt.d_year, + sum_agg desc, + brand_id +limit 100 +-- end query 3 in stream 0 using template query3.tpl diff --git 
a/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql new file mode 100755 index 000000000000..d11696e5e0c3 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql @@ -0,0 +1,45 @@ +-- start query 34 in stream 0 using template query34.tpl +select + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +from + (select + ss_ticket_number, + ss_customer_sk, + count(*) cnt + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and (date_dim.d_dom between 1 and 3 + or date_dim.d_dom between 25 and 28) + and (household_demographics.hd_buy_potential = '>10000' + or household_demographics.hd_buy_potential = 'Unknown') + and household_demographics.hd_vehicle_count > 0 + and (case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count / household_demographics.hd_vehicle_count else null end) > 1.2 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_county in ('Saginaw County', 'Sumner County', 'Appanoose County', 'Daviess County', 'Fairfield County', 'Raleigh County', 'Ziebach County', 'Williamson County') + and ss_sold_date_sk between 2450816 and 2451910 -- partition key filter + group by + ss_ticket_number, + ss_customer_sk + ) dn, + customer +where + ss_customer_sk = c_customer_sk + and cnt between 15 and 20 +order by + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag desc +-- end query 34 in stream 0 using template query34.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql new file mode 100755 index 000000000000..b6332a8afbeb --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql @@ -0,0 +1,28 @@ +-- start query 42 in stream 0 using template query42.tpl +select + dt.d_year, + item.i_category_id, + item.i_category, + sum(ss_ext_sales_price) +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manager_id = 1 + and dt.d_moy = 12 + and dt.d_year = 1998 + and ss_sold_date_sk between 2451149 and 2451179 -- partition key filter +group by + dt.d_year, + item.i_category_id, + item.i_category +order by + sum(ss_ext_sales_price) desc, + dt.d_year, + item.i_category_id, + item.i_category +limit 100 +-- end query 42 in stream 0 using template query42.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql new file mode 100755 index 000000000000..cc2040b2fdb7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql @@ -0,0 +1,36 @@ +-- start query 43 in stream 0 using template query43.tpl +select + s_store_name, + s_store_id, + sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales, + sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales, + sum(case when (d_day_name = 'Tuesday') then ss_sales_price else null end) tue_sales, + sum(case when (d_day_name = 'Wednesday') then ss_sales_price else null end) wed_sales, + sum(case when (d_day_name = 'Thursday') then ss_sales_price else null end) thu_sales, + sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) 
fri_sales, + sum(case when (d_day_name = 'Saturday') then ss_sales_price else null end) sat_sales +from + date_dim, + store_sales, + store +where + d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + and s_gmt_offset = -5 + and d_year = 1998 + and ss_sold_date_sk between 2450816 and 2451179 -- partition key filter +group by + s_store_name, + s_store_id +order by + s_store_name, + s_store_id, + sun_sales, + mon_sales, + tue_sales, + wed_sales, + thu_sales, + fri_sales, + sat_sales +limit 100 +-- end query 43 in stream 0 using template query43.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql new file mode 100755 index 000000000000..52b7ba4f4b86 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql @@ -0,0 +1,80 @@ +-- start query 46 in stream 0 using template query46.tpl +select + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + amt, + profit +from + (select + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + from + store_sales, + date_dim, + store, + household_demographics, + customer_address + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and store_sales.ss_addr_sk = customer_address.ca_address_sk + and (household_demographics.hd_dep_count = 5 + or household_demographics.hd_vehicle_count = 3) + and date_dim.d_dow in (6, 0) + and date_dim.d_year in (1999, 1999 + 1, 1999 + 2) + and store.s_city in ('Midway', 'Concord', 'Spring Hill', 'Brownsville', 'Greenville') + -- partition key filter + and ss_sold_date_sk in (2451181, 2451182, 2451188, 2451189, 2451195, 2451196, 2451202, 2451203, 2451209, 2451210, 2451216, 2451217, + 2451223, 2451224, 2451230, 2451231, 2451237, 2451238, 2451244, 2451245, 2451251, 2451252, 2451258, 2451259, + 2451265, 2451266, 2451272, 2451273, 2451279, 2451280, 2451286, 2451287, 2451293, 2451294, 2451300, 2451301, + 2451307, 2451308, 2451314, 2451315, 2451321, 2451322, 2451328, 2451329, 2451335, 2451336, 2451342, 2451343, + 2451349, 2451350, 2451356, 2451357, 2451363, 2451364, 2451370, 2451371, 2451377, 2451378, 2451384, 2451385, + 2451391, 2451392, 2451398, 2451399, 2451405, 2451406, 2451412, 2451413, 2451419, 2451420, 2451426, 2451427, + 2451433, 2451434, 2451440, 2451441, 2451447, 2451448, 2451454, 2451455, 2451461, 2451462, 2451468, 2451469, + 2451475, 2451476, 2451482, 2451483, 2451489, 2451490, 2451496, 2451497, 2451503, 2451504, 2451510, 2451511, + 2451517, 2451518, 2451524, 2451525, 2451531, 2451532, 2451538, 2451539, 2451545, 2451546, 2451552, 2451553, + 2451559, 2451560, 2451566, 2451567, 2451573, 2451574, 2451580, 2451581, 2451587, 2451588, 2451594, 2451595, + 2451601, 2451602, 2451608, 2451609, 2451615, 2451616, 2451622, 2451623, 2451629, 2451630, 2451636, 2451637, + 2451643, 2451644, 2451650, 2451651, 2451657, 2451658, 2451664, 2451665, 2451671, 2451672, 2451678, 2451679, + 2451685, 2451686, 2451692, 2451693, 2451699, 2451700, 2451706, 2451707, 2451713, 2451714, 2451720, 2451721, + 2451727, 2451728, 2451734, 2451735, 2451741, 2451742, 2451748, 2451749, 2451755, 2451756, 2451762, 2451763, + 2451769, 2451770, 2451776, 2451777, 2451783, 2451784, 2451790, 2451791, 2451797, 2451798, 2451804, 2451805, + 2451811, 2451812, 2451818, 2451819, 2451825, 2451826, 2451832, 2451833, 2451839, 2451840, 2451846, 2451847, + 2451853, 2451854, 
2451860, 2451861, 2451867, 2451868, 2451874, 2451875, 2451881, 2451882, 2451888, 2451889, + 2451895, 2451896, 2451902, 2451903, 2451909, 2451910, 2451916, 2451917, 2451923, 2451924, 2451930, 2451931, + 2451937, 2451938, 2451944, 2451945, 2451951, 2451952, 2451958, 2451959, 2451965, 2451966, 2451972, 2451973, + 2451979, 2451980, 2451986, 2451987, 2451993, 2451994, 2452000, 2452001, 2452007, 2452008, 2452014, 2452015, + 2452021, 2452022, 2452028, 2452029, 2452035, 2452036, 2452042, 2452043, 2452049, 2452050, 2452056, 2452057, + 2452063, 2452064, 2452070, 2452071, 2452077, 2452078, 2452084, 2452085, 2452091, 2452092, 2452098, 2452099, + 2452105, 2452106, 2452112, 2452113, 2452119, 2452120, 2452126, 2452127, 2452133, 2452134, 2452140, 2452141, + 2452147, 2452148, 2452154, 2452155, 2452161, 2452162, 2452168, 2452169, 2452175, 2452176, 2452182, 2452183, + 2452189, 2452190, 2452196, 2452197, 2452203, 2452204, 2452210, 2452211, 2452217, 2452218, 2452224, 2452225, + 2452231, 2452232, 2452238, 2452239, 2452245, 2452246, 2452252, 2452253, 2452259, 2452260, 2452266, 2452267, + 2452273, 2452274) + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + ca_city + ) dn, + customer, + customer_address current_addr +where + ss_customer_sk = c_customer_sk + and customer.c_current_addr_sk = current_addr.ca_address_sk + and current_addr.ca_city <> bought_city +order by + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number +limit 100 +-- end query 46 in stream 0 using template query46.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql new file mode 100755 index 000000000000..a510eefb13e1 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql @@ -0,0 +1,27 @@ +-- start query 52 in stream 0 using template query52.tpl +select + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + sum(ss_ext_sales_price) ext_price +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manager_id = 1 + and dt.d_moy = 12 + and dt.d_year = 1998 + and ss_sold_date_sk between 2451149 and 2451179 -- added for partition pruning +group by + dt.d_year, + item.i_brand, + item.i_brand_id +order by + dt.d_year, + ext_price desc, + brand_id +limit 100 +-- end query 52 in stream 0 using template query52.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql new file mode 100755 index 000000000000..fb7bb7518385 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql @@ -0,0 +1,37 @@ +-- start query 53 in stream 0 using template query53.tpl +select + * +from + (select + i_manufact_id, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over (partition by i_manufact_id) avg_quarterly_sales + from + item, + store_sales, + date_dim, + store + where + ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_store_sk = s_store_sk + and d_month_seq in (1212, 1212 + 1, 1212 + 2, 1212 + 3, 1212 + 4, 1212 + 5, 1212 + 6, 1212 + 7, 1212 + 8, 1212 + 9, 1212 + 10, 1212 + 11) + and ((i_category in ('Books', 'Children', 'Electronics') + and i_class in ('personal', 'portable', 'reference', 'self-help') + and i_brand in ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')) + or (i_category in ('Women', 'Music', 'Men') + and i_class in ('accessories', 'classical', 
'fragrances', 'pants') + and i_brand in ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1'))) + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + i_manufact_id, + d_qoy + ) tmp1 +where + case when avg_quarterly_sales > 0 then abs (sum_sales - avg_quarterly_sales) / avg_quarterly_sales else null end > 0.1 +order by + avg_quarterly_sales, + sum_sales, + i_manufact_id +limit 100 +-- end query 53 in stream 0 using template query53.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql new file mode 100755 index 000000000000..47b1f0292d90 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql @@ -0,0 +1,24 @@ +-- start query 55 in stream 0 using template query55.tpl +select + i_brand_id brand_id, + i_brand brand, + sum(ss_ext_sales_price) ext_price +from + date_dim, + store_sales, + item +where + d_date_sk = ss_sold_date_sk + and ss_item_sk = i_item_sk + and i_manager_id = 48 + and d_moy = 11 + and d_year = 2001 + and ss_sold_date_sk between 2452215 and 2452244 +group by + i_brand, + i_brand_id +order by + ext_price desc, + i_brand_id +limit 100 +-- end query 55 in stream 0 using template query55.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql new file mode 100755 index 000000000000..3d5c4e9d6441 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql @@ -0,0 +1,83 @@ +-- start query 59 in stream 0 using template query59.tpl +with + wss as + (select + d_week_seq, + ss_store_sk, + sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales, + sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales, + sum(case when (d_day_name = 'Tuesday') then ss_sales_price else null end) tue_sales, + sum(case when (d_day_name = 'Wednesday') then ss_sales_price else null end) wed_sales, + sum(case when (d_day_name = 'Thursday') then ss_sales_price else null end) thu_sales, + sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) fri_sales, + sum(case when (d_day_name = 'Saturday') then ss_sales_price else null end) sat_sales + from + store_sales, + date_dim + where + d_date_sk = ss_sold_date_sk + group by + d_week_seq, + ss_store_sk + ) +select + s_store_name1, + s_store_id1, + d_week_seq1, + sun_sales1 / sun_sales2, + mon_sales1 / mon_sales2, + tue_sales1 / tue_sales1, + wed_sales1 / wed_sales2, + thu_sales1 / thu_sales2, + fri_sales1 / fri_sales2, + sat_sales1 / sat_sales2 +from + (select + s_store_name s_store_name1, + wss.d_week_seq d_week_seq1, + s_store_id s_store_id1, + sun_sales sun_sales1, + mon_sales mon_sales1, + tue_sales tue_sales1, + wed_sales wed_sales1, + thu_sales thu_sales1, + fri_sales fri_sales1, + sat_sales sat_sales1 + from + wss, + store, + date_dim d + where + d.d_week_seq = wss.d_week_seq + and ss_store_sk = s_store_sk + and d_month_seq between 1185 and 1185 + 11 + ) y, + (select + s_store_name s_store_name2, + wss.d_week_seq d_week_seq2, + s_store_id s_store_id2, + sun_sales sun_sales2, + mon_sales mon_sales2, + tue_sales tue_sales2, + wed_sales wed_sales2, + thu_sales thu_sales2, + fri_sales fri_sales2, + sat_sales sat_sales2 + from + wss, + store, + date_dim d + where + d.d_week_seq = wss.d_week_seq + and ss_store_sk = s_store_sk + and d_month_seq between 1185 + 12 and 1185 + 23 + ) x +where + s_store_id1 = s_store_id2 + and d_week_seq1 = 
d_week_seq2 - 52 +order by + s_store_name1, + s_store_id1, + d_week_seq1 +limit 100 +-- end query 59 in stream 0 using template query59.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql new file mode 100755 index 000000000000..b71199ab17d0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql @@ -0,0 +1,29 @@ +-- start query 63 in stream 0 using template query63.tpl +select * +from (select i_manager_id + ,sum(ss_sales_price) sum_sales + ,avg(sum(ss_sales_price)) over (partition by i_manager_id) avg_monthly_sales + from item + ,store_sales + ,date_dim + ,store + where ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_sold_date_sk between 2452123 and 2452487 + and ss_store_sk = s_store_sk + and d_month_seq in (1219,1219+1,1219+2,1219+3,1219+4,1219+5,1219+6,1219+7,1219+8,1219+9,1219+10,1219+11) + and (( i_category in ('Books','Children','Electronics') + and i_class in ('personal','portable','reference','self-help') + and i_brand in ('scholaramalgamalg #14','scholaramalgamalg #7', + 'exportiunivamalg #9','scholaramalgamalg #9')) + or( i_category in ('Women','Music','Men') + and i_class in ('accessories','classical','fragrances','pants') + and i_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1', + 'importoamalg #1'))) +group by i_manager_id, d_moy) tmp1 +where case when avg_monthly_sales > 0 then abs (sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 +order by i_manager_id + ,avg_monthly_sales + ,sum_sales +limit 100 +-- end query 63 in stream 0 using template query63.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql new file mode 100755 index 000000000000..7344feeff6a9 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql @@ -0,0 +1,58 @@ +-- start query 65 in stream 0 using template query65.tpl +select + s_store_name, + i_item_desc, + sc.revenue, + i_current_price, + i_wholesale_cost, + i_brand +from + store, + item, + (select + ss_store_sk, + avg(revenue) as ave + from + (select + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) as revenue + from + store_sales, + date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + ss_store_sk, + ss_item_sk + ) sa + group by + ss_store_sk + ) sb, + (select + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) as revenue + from + store_sales, + date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + ss_store_sk, + ss_item_sk + ) sc +where + sb.ss_store_sk = sc.ss_store_sk + and sc.revenue <= 0.1 * sb.ave + and s_store_sk = sc.ss_store_sk + and i_item_sk = sc.ss_item_sk +order by + s_store_name, + i_item_desc +limit 100 +-- end query 65 in stream 0 using template query65.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql new file mode 100755 index 000000000000..94df4b3f57a9 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql @@ -0,0 +1,62 @@ +-- start query 68 in stream 0 using template query68.tpl +-- changed to match exact same partitions in original query +select + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, 
+ extended_price, + extended_tax, + list_price +from + (select + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_ext_sales_price) extended_price, + sum(ss_ext_list_price) list_price, + sum(ss_ext_tax) extended_tax + from + store_sales, + date_dim, + store, + household_demographics, + customer_address + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and store_sales.ss_addr_sk = customer_address.ca_address_sk + and date_dim.d_dom between 1 and 2 + and (household_demographics.hd_dep_count = 5 + or household_demographics.hd_vehicle_count = 3) + and date_dim.d_year in (1999, 1999 + 1, 1999 + 2) + and store.s_city in ('Midway', 'Fairview') + -- partition key filter + and ss_sold_date_sk in (2451180, 2451181, 2451211, 2451212, 2451239, 2451240, 2451270, 2451271, 2451300, 2451301, 2451331, + 2451332, 2451361, 2451362, 2451392, 2451393, 2451423, 2451424, 2451453, 2451454, 2451484, 2451485, + 2451514, 2451515, 2451545, 2451546, 2451576, 2451577, 2451605, 2451606, 2451636, 2451637, 2451666, + 2451667, 2451697, 2451698, 2451727, 2451728, 2451758, 2451759, 2451789, 2451790, 2451819, 2451820, + 2451850, 2451851, 2451880, 2451881, 2451911, 2451912, 2451942, 2451943, 2451970, 2451971, 2452001, + 2452002, 2452031, 2452032, 2452062, 2452063, 2452092, 2452093, 2452123, 2452124, 2452154, 2452155, + 2452184, 2452185, 2452215, 2452216, 2452245, 2452246) + --and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter (3 months) + --and d_date between '1999-01-01' and '1999-03-31' + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + ca_city + ) dn, + customer, + customer_address current_addr +where + ss_customer_sk = c_customer_sk + and customer.c_current_addr_sk = current_addr.ca_address_sk + and current_addr.ca_city <> bought_city +order by + c_last_name, + ss_ticket_number +limit 100 +-- end query 68 in stream 0 using template query68.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql new file mode 100755 index 000000000000..c61a2d0d2a8f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql @@ -0,0 +1,31 @@ +-- start query 7 in stream 0 using template query7.tpl +select + i_item_id, + avg(ss_quantity) agg1, + avg(ss_list_price) agg2, + avg(ss_coupon_amt) agg3, + avg(ss_sales_price) agg4 +from + store_sales, + customer_demographics, + date_dim, + item, + promotion +where + ss_sold_date_sk = d_date_sk + and ss_item_sk = i_item_sk + and ss_cdemo_sk = cd_demo_sk + and ss_promo_sk = p_promo_sk + and cd_gender = 'F' + and cd_marital_status = 'W' + and cd_education_status = 'Primary' + and (p_channel_email = 'N' + or p_channel_event = 'N') + and d_year = 1998 + and ss_sold_date_sk between 2450815 and 2451179 -- partition key filter +group by + i_item_id +order by + i_item_id +limit 100 +-- end query 7 in stream 0 using template query7.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql new file mode 100755 index 000000000000..8703910b305a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql @@ -0,0 +1,49 @@ +-- start query 73 in stream 0 using template query73.tpl +select + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +from + (select + ss_ticket_number, + ss_customer_sk, + count(*) cnt + 
from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and date_dim.d_dom between 1 and 2 + and (household_demographics.hd_buy_potential = '>10000' + or household_demographics.hd_buy_potential = 'Unknown') + and household_demographics.hd_vehicle_count > 0 + and case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count / household_demographics.hd_vehicle_count else null end > 1 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_county in ('Fairfield County','Ziebach County','Bronx County','Barrow County') + -- partition key filter + and ss_sold_date_sk in (2450815, 2450816, 2450846, 2450847, 2450874, 2450875, 2450905, 2450906, 2450935, 2450936, 2450966, 2450967, + 2450996, 2450997, 2451027, 2451028, 2451058, 2451059, 2451088, 2451089, 2451119, 2451120, 2451149, + 2451150, 2451180, 2451181, 2451211, 2451212, 2451239, 2451240, 2451270, 2451271, 2451300, 2451301, + 2451331, 2451332, 2451361, 2451362, 2451392, 2451393, 2451423, 2451424, 2451453, 2451454, 2451484, + 2451485, 2451514, 2451515, 2451545, 2451546, 2451576, 2451577, 2451605, 2451606, 2451636, 2451637, + 2451666, 2451667, 2451697, 2451698, 2451727, 2451728, 2451758, 2451759, 2451789, 2451790, 2451819, + 2451820, 2451850, 2451851, 2451880, 2451881) + --and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter (3 months) + group by + ss_ticket_number, + ss_customer_sk + ) dj, + customer +where + ss_customer_sk = c_customer_sk + and cnt between 1 and 5 +order by + cnt desc +-- end query 73 in stream 0 using template query73.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql new file mode 100755 index 000000000000..4254310ecd10 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql @@ -0,0 +1,59 @@ +-- start query 79 in stream 0 using template query79.tpl +select + c_last_name, + c_first_name, + substr(s_city, 1, 30), + ss_ticket_number, + amt, + profit +from + (select + ss_ticket_number, + ss_customer_sk, + store.s_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and (household_demographics.hd_dep_count = 8 + or household_demographics.hd_vehicle_count > 0) + and date_dim.d_dow = 1 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_number_employees between 200 and 295 + and ss_sold_date_sk between 2450819 and 2451904 + -- partition key filter + --and ss_sold_date_sk in (2450819, 2450826, 2450833, 2450840, 2450847, 2450854, 2450861, 2450868, 2450875, 2450882, 2450889, + -- 2450896, 2450903, 2450910, 2450917, 2450924, 2450931, 2450938, 2450945, 2450952, 2450959, 2450966, 2450973, 2450980, 2450987, + -- 2450994, 2451001, 2451008, 2451015, 2451022, 2451029, 2451036, 2451043, 2451050, 2451057, 2451064, 2451071, 2451078, 2451085, + -- 2451092, 2451099, 2451106, 2451113, 2451120, 2451127, 2451134, 2451141, 2451148, 2451155, 2451162, 2451169, 2451176, 2451183, + -- 2451190, 2451197, 2451204, 2451211, 2451218, 2451225, 2451232, 2451239, 2451246, 2451253, 2451260, 2451267, 2451274, 2451281, + -- 2451288, 2451295, 2451302, 2451309, 
2451316, 2451323, 2451330, 2451337, 2451344, 2451351, 2451358, 2451365, 2451372, 2451379, + -- 2451386, 2451393, 2451400, 2451407, 2451414, 2451421, 2451428, 2451435, 2451442, 2451449, 2451456, 2451463, 2451470, 2451477, + -- 2451484, 2451491, 2451498, 2451505, 2451512, 2451519, 2451526, 2451533, 2451540, 2451547, 2451554, 2451561, 2451568, 2451575, + -- 2451582, 2451589, 2451596, 2451603, 2451610, 2451617, 2451624, 2451631, 2451638, 2451645, 2451652, 2451659, 2451666, 2451673, + -- 2451680, 2451687, 2451694, 2451701, 2451708, 2451715, 2451722, 2451729, 2451736, 2451743, 2451750, 2451757, 2451764, 2451771, + -- 2451778, 2451785, 2451792, 2451799, 2451806, 2451813, 2451820, 2451827, 2451834, 2451841, 2451848, 2451855, 2451862, 2451869, + -- 2451876, 2451883, 2451890, 2451897, 2451904) + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + store.s_city + ) ms, + customer +where + ss_customer_sk = c_customer_sk +order by + c_last_name, + c_first_name, + substr(s_city, 1, 30), + profit + limit 100 +-- end query 79 in stream 0 using template query79.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql new file mode 100755 index 000000000000..b1d814af5e57 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql @@ -0,0 +1,43 @@ +-- start query 89 in stream 0 using template query89.tpl +select + * +from + (select + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over (partition by i_category, i_brand, s_store_name, s_company_name) avg_monthly_sales + from + item, + store_sales, + date_dim, + store + where + ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_store_sk = s_store_sk + and d_year in (2000) + and ((i_category in ('Home', 'Books', 'Electronics') + and i_class in ('wallpaper', 'parenting', 'musical')) + or (i_category in ('Shoes', 'Jewelry', 'Men') + and i_class in ('womens', 'birdal', 'pants'))) + and ss_sold_date_sk between 2451545 and 2451910 -- partition key filter + group by + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy + ) tmp1 +where + case when (avg_monthly_sales <> 0) then (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) else null end > 0.1 +order by + sum_sales - avg_monthly_sales, + s_store_name +limit 100 +-- end query 89 in stream 0 using template query89.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql new file mode 100755 index 000000000000..f53f2f5f9c5b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql @@ -0,0 +1,32 @@ +-- start query 98 in stream 0 using template query98.tpl +select + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ss_ext_sales_price) as itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) over (partition by i_class) as revenueratio +from + store_sales, + item, + date_dim +where + ss_item_sk = i_item_sk + and i_category in ('Jewelry', 'Sports', 'Books') + and ss_sold_date_sk = d_date_sk + and ss_sold_date_sk between 2451911 and 2451941 -- partition key filter (1 calendar month) + and d_date between '2001-01-01' and '2001-01-31' +group by + i_item_id, + i_item_desc, + i_category, + i_class, + i_current_price +order by + i_category, + i_class, + i_item_id, + i_item_desc, + revenueratio +--limit 1000; -- added limit +-- end query 98 in stream 0 using 
template query98.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql new file mode 100755 index 000000000000..bf58b4bb3c5a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql @@ -0,0 +1,14 @@ +select + count(*) as total, + count(ss_sold_date_sk) as not_null_total, + count(distinct ss_sold_date_sk) as unique_days, + max(ss_sold_date_sk) as max_ss_sold_date_sk, + max(ss_sold_time_sk) as max_ss_sold_time_sk, + max(ss_item_sk) as max_ss_item_sk, + max(ss_customer_sk) as max_ss_customer_sk, + max(ss_cdemo_sk) as max_ss_cdemo_sk, + max(ss_hdemo_sk) as max_ss_hdemo_sk, + max(ss_addr_sk) as max_ss_addr_sk, + max(ss_store_sk) as max_ss_store_sk, + max(ss_promo_sk) as max_ss_promo_sk +from store_sales diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 62a75343a094..1aea33766407 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql +import java.sql.{Date, Timestamp} + import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.test.SharedSQLContext /** @@ -67,6 +70,30 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { } } + test("percentile_approx, different column types") { + withTempView(table) { + val intSeq = 1 to 1000 + val data: Seq[(java.math.BigDecimal, Date, Timestamp)] = intSeq.map { i => + (new java.math.BigDecimal(i), DateTimeUtils.toJavaDate(i), DateTimeUtils.toJavaTimestamp(i)) + } + data.toDF("cdecimal", "cdate", "ctimestamp").createOrReplaceTempView(table) + checkAnswer( + spark.sql( + s"""SELECT + | percentile_approx(cdecimal, array(0.25, 0.5, 0.75D)), + | percentile_approx(cdate, array(0.25, 0.5, 0.75D)), + | percentile_approx(ctimestamp, array(0.25, 0.5, 0.75D)) + |FROM $table + """.stripMargin), + Row( + Seq("250.000000000000000000", "500.000000000000000000", "750.000000000000000000") + .map(i => new java.math.BigDecimal(i)), + Seq(250, 500, 750).map(DateTimeUtils.toJavaDate), + Seq(250, 500, 750).map(i => DateTimeUtils.toJavaTimestamp(i.toLong))) + ) + } + } + test("percentile_approx, multiple records with the minimum value in a partition") { withTempView(table) { spark.sparkContext.makeRDD(Seq(1, 1, 2, 1, 1, 3, 1, 1, 4, 1, 1, 5), 4).toDF("col") @@ -88,7 +115,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { val accuracies = Array(1, 10, 100, 1000, 10000) val errors = accuracies.map { accuracy => val df = spark.sql(s"SELECT percentile_approx(col, 0.25, $accuracy) FROM $table") - val approximatePercentile = df.collect().head.getDouble(0) + val approximatePercentile = df.collect().head.getInt(0) val error = Math.abs(approximatePercentile - expectedPercentile) error } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index e66fe97afad4..1e52445f28fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -313,7 +313,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum - assert(cached.stats(sqlConf).sizeInBytes === actualSizeInBytes) + assert(cached.stats.sizeInBytes === actualSizeInBytes) } } @@ -420,7 +420,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext * Verifies that the plan for `df` contains `expected` number of Exchange operators. */ private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { - assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }.size == expected) + assert( + df.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }.size == expected) } test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { @@ -631,7 +632,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val ds2 = sql( """ - |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |SELECT * FROM (SELECT c1, max(c1) FROM t1 GROUP BY c1) |WHERE |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) |OR @@ -647,7 +648,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTable("t") { withTempPath { path => Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath) - sql(s"CREATE TABLE t USING parquet LOCATION '$path'") + sql(s"CREATE TABLE t USING parquet LOCATION '${path.toURI}'") spark.catalog.cacheTable("t") spark.table("t").select($"i").cache() checkAnswer(spark.table("t").select($"i"), Row(1)) @@ -683,20 +684,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext Seq(1).toDF("c1").createOrReplaceTempView("t1") Seq(2).toDF("c1").createOrReplaceTempView("t2") - sql( + val sql1 = """ |SELECT * FROM t1 |WHERE |NOT EXISTS (SELECT * FROM t2) - """.stripMargin).cache() + """.stripMargin + sql(sql1).cache() - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |NOT EXISTS (SELECT * FROM t2) - """.stripMargin) + val cachedDs = sql(sql1) assert(getNumInMemoryRelations(cachedDs) == 1) // Additional predicate in the subquery plan should cause a cache miss @@ -717,20 +713,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext Seq(1).toDF("c1").createOrReplaceTempView("t2") // Simple correlated predicate in subquery - sql( + val sqlText = """ |SELECT * FROM t1 |WHERE |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) - """.stripMargin).cache() + """.stripMargin + sql(sqlText).cache() - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) - """.stripMargin) + val cachedDs = sql(sqlText) assert(getNumInMemoryRelations(cachedDs) == 1) } } @@ -741,22 +732,16 @@ class CachedTableSuite extends QueryTest with 
SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("t1") // underlying table t1 is cached as well as the query that refers to it. - val ds = - sql( + val sqlText = """ |SELECT * FROM t1 |WHERE |NOT EXISTS (SELECT * FROM t1) - """.stripMargin) + """.stripMargin + val ds = sql(sqlText) assert(getNumInMemoryRelations(ds) == 2) - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |NOT EXISTS (SELECT * FROM t1) - """.stripMargin).cache() + val cachedDs = sql(sqlText).cache() assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3) } } @@ -769,45 +754,31 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext Seq(1).toDF("c1").createOrReplaceTempView("t4") // Nested predicate subquery - sql( + val sql1 = """ |SELECT * FROM t1 |WHERE |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) - """.stripMargin).cache() + """.stripMargin + sql(sql1).cache() - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) - """.stripMargin) + val cachedDs = sql(sql1) assert(getNumInMemoryRelations(cachedDs) == 1) // Scalar subquery and predicate subquery - sql( + val sql2 = """ - |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |SELECT * FROM (SELECT c1, max(c1) FROM t1 GROUP BY c1) |WHERE |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) |OR |EXISTS (SELECT c1 FROM t3) |OR |c1 IN (SELECT c1 FROM t4) - """.stripMargin).cache() + """.stripMargin + sql(sql2).cache() - val cachedDs2 = - sql( - """ - |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) - |WHERE - |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) - |OR - |EXISTS (SELECT c1 FROM t3) - |OR - |c1 IN (SELECT c1 FROM t4) - """.stripMargin) + val cachedDs2 = sql(sql2) assert(getNumInMemoryRelations(cachedDs2) == 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b0f398dab745..7c45be21961d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -39,6 +39,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) } + private lazy val nullData = Seq( + (Some(1), Some(1)), (Some(1), Some(2)), (Some(1), None), (None, None)).toDF("a", "b") + test("column names with space") { val df = Seq((1, "a")).toDF("name with space", "name.with.dot") @@ -283,23 +286,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("<=>") { - checkAnswer( - testData2.filter($"a" === 1), - testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) - - checkAnswer( - testData2.filter($"a" === $"b"), - testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) - } - - test("=!=") { - val nullData = spark.createDataFrame(sparkContext.parallelize( - Row(1, 1) :: - Row(1, 2) :: - Row(1, null) :: - Row(null, null) :: Nil), - StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) - checkAnswer( nullData.filter($"b" <=> 1), Row(1, 1) :: Nil) @@ -321,7 +307,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer( nullData2.filter($"a" <=> null), Row(null) :: Nil) + } + + test("=!=") { + checkAnswer( + nullData.filter($"b" =!= 1), + Row(1, 2) :: Nil) + + checkAnswer(nullData.filter($"b" =!= null), Nil) + checkAnswer( + 
nullData.filter($"a" =!= $"b"), + Row(1, 2) :: Nil) } test(">") { @@ -533,6 +530,63 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) } + test("input_file_name, input_file_block_start, input_file_block_length - more than one source") { + withTempView("tempView1") { + withTable("tab1", "tab2") { + val data = sparkContext.parallelize(0 to 9).toDF("id") + data.write.saveAsTable("tab1") + data.write.saveAsTable("tab2") + data.createOrReplaceTempView("tempView1") + Seq("input_file_name", "input_file_block_start", "input_file_block_length").foreach { f => + val e = intercept[AnalysisException] { + sql(s"SELECT *, $f() FROM tab1 JOIN tab2 ON tab1.id = tab2.id") + }.getMessage + assert(e.contains(s"'$f' does not support more than one source")) + } + + def checkResult( + fromClause: String, + exceptionExpected: Boolean, + numExpectedRows: Int = 0): Unit = { + val stmt = s"SELECT *, input_file_name() FROM ($fromClause)" + if (exceptionExpected) { + val e = intercept[AnalysisException](sql(stmt)).getMessage + assert(e.contains("'input_file_name' does not support more than one source")) + } else { + assert(sql(stmt).count() == numExpectedRows) + } + } + + checkResult( + "SELECT * FROM tab1 UNION ALL SELECT * FROM tab2 UNION ALL SELECT * FROM tab2", + exceptionExpected = false, + numExpectedRows = 30) + + checkResult( + "(SELECT * FROM tempView1 NATURAL JOIN tab2) UNION ALL SELECT * FROM tab2", + exceptionExpected = false, + numExpectedRows = 20) + + checkResult( + "(SELECT * FROM tab1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tempView1", + exceptionExpected = false, + numExpectedRows = 20) + + checkResult( + "(SELECT * FROM tempView1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tab2", + exceptionExpected = true) + + checkResult( + "(SELECT * FROM tab1 NATURAL JOIN tab2) UNION ALL SELECT * FROM tab2", + exceptionExpected = true) + + checkResult( + "(SELECT * FROM tab1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tab2", + exceptionExpected = true) + } + } + } + test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala new file mode 100644 index 000000000000..2c1e5db5fd9b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.commons.math3.stat.inference.ChiSquareTest + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + + +class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + test("SPARK-22160 spark.sql.execution.rangeExchange.sampleSizePerPartition") { + // In this test, we run a sort and compute the histogram for partition size post shuffle. + // With a high sample count, the partition size should be more evenly distributed, and has a + // low chi-sq test value. + // Also the whole code path for range partitioning as implemented should be deterministic + // (it uses the partition id as the seed), so this test shouldn't be flaky. + + val numPartitions = 4 + + def computeChiSquareTest(): Double = { + val n = 10000 + // Trigger a sort + val data = spark.range(0, n, 1, 1).sort('id) + .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() + + // Compute histogram for the number of records per partition post sort + val dist = data.groupBy(_._1).map(_._2.length.toLong).toArray + assert(dist.length == 4) + + new ChiSquareTest().chiSquare( + Array.fill(numPartitions) { n.toDouble / numPartitions }, + dist) + } + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) { + // The default chi-sq value should be low + assert(computeChiSquareTest() < 100) + + withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") { + // If we only sample one point, the range boundaries will be pretty bad and the + // chi-sq value would be very high. + assert(computeChiSquareTest() > 1000) + } + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 8569c2d76b69..8549eac58ee9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import scala.util.Random + +import org.apache.spark.sql.execution.WholeStageCodegenExec +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -65,9 +69,9 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( decimalData.groupBy("a").agg(sum("b")), - Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(3.0)), - Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(3.0)), - Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0))) + Seq(Row(new java.math.BigDecimal(1), new java.math.BigDecimal(3)), + Row(new java.math.BigDecimal(2), new java.math.BigDecimal(3)), + Row(new java.math.BigDecimal(3), new java.math.BigDecimal(3))) ) val decimalDataWithNulls = spark.sparkContext.parallelize( @@ -80,10 +84,10 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { DecimalData(null, 2) :: Nil).toDF() checkAnswer( decimalDataWithNulls.groupBy("a").agg(sum("b")), - Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.0)), - Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.0)), - Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)), - Row(null, new java.math.BigDecimal(2.0))) + Seq(Row(new java.math.BigDecimal(1), new 
java.math.BigDecimal(1)), + Row(new java.math.BigDecimal(2), new java.math.BigDecimal(1)), + Row(new java.math.BigDecimal(3), new java.math.BigDecimal(3)), + Row(null, new java.math.BigDecimal(2))) ) } @@ -186,6 +190,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-21980: References in grouping functions should be indexed with semanticEquals") { + checkAnswer( + courseSales.cube("course", "year") + .agg(grouping("CouRse"), grouping("year")), + Row("Java", 2012, 0, 0) :: + Row("Java", 2013, 0, 0) :: + Row("Java", null, 0, 1) :: + Row("dotNET", 2012, 0, 0) :: + Row("dotNET", 2013, 0, 0) :: + Row("dotNET", null, 0, 1) :: + Row(null, 2012, 1, 0) :: + Row(null, 2013, 1, 0) :: + Row(null, null, 1, 1) :: Nil + ) + } + test("rollup overlapping columns") { checkAnswer( testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), @@ -259,19 +279,19 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) + Row(new java.math.BigDecimal(2))) checkAnswer( decimalData.agg(avg('a), sumDistinct('a)), // non-partial - Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + Row(new java.math.BigDecimal(2), new java.math.BigDecimal(6)) :: Nil) checkAnswer( decimalData.agg(avg('a cast DecimalType(10, 2))), - Row(new java.math.BigDecimal(2.0))) + Row(new java.math.BigDecimal(2))) // non-partial checkAnswer( decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), - Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + Row(new java.math.BigDecimal(2), new java.math.BigDecimal(6)) :: Nil) } test("null average") { @@ -460,6 +480,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { df.select(collect_set($"a"), collect_set($"b")), Seq(Row(Seq(1, 2, 3), Seq(2, 4))) ) + + checkDataset( + df.select(collect_set($"a").as("aSet")).as[Set[Int]], + Set(1, 2, 3)) + checkDataset( + df.select(collect_set($"b").as("bSet")).as[Set[Int]], + Set(2, 4)) + checkDataset( + df.select(collect_set($"a"), collect_set($"b")).as[(Set[Int], Set[Int])], + Seq(Set(1, 2, 3) -> Set(2, 4)): _*) } test("collect functions structs") { @@ -507,12 +537,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } - test("SQL decimal test (used for catching certain demical handling bugs in aggregates)") { + test("SQL decimal test (used for catching certain decimal handling bugs in aggregates)") { checkAnswer( decimalData.groupBy('a cast DecimalType(10, 2)).agg(avg('b cast DecimalType(10, 2))), - Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.5)), - Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.5)), - Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(1.5)))) + Seq(Row(new java.math.BigDecimal(1), new java.math.BigDecimal("1.5")), + Row(new java.math.BigDecimal(2), new java.math.BigDecimal("1.5")), + Row(new java.math.BigDecimal(3), new java.math.BigDecimal("1.5")))) } test("SPARK-17616: distinct aggregate combined with a non-partial aggregate") { @@ -533,10 +563,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("SPARK-17237 remove backticks in a pivot result schema") { val df = Seq((2, 3, 4), (3, 4, 5)).toDF("a", "x", "y") - checkAnswer( - df.groupBy("a").pivot("x").agg(count("y"), avg("y")).na.fill(0), - Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) - 
) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + checkAnswer( + df.groupBy("a").pivot("x").agg(count("y"), avg("y")).na.fill(0), + Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) + ) + } } test("aggregate function in GROUP BY") { @@ -545,4 +577,63 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } assert(e.message.contains("aggregate functions are not allowed in GROUP BY")) } + + private def assertNoExceptions(c: Column): Unit = { + for ((wholeStage, useObjectHashAgg) <- + Seq((true, true), (true, false), (false, true), (false, false))) { + withSQLConf( + (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString), + (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) { + + val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y") + + // test case for HashAggregate + val hashAggDF = df.groupBy("x").agg(c, sum("y")) + val hashAggPlan = hashAggDF.queryExecution.executedPlan + if (wholeStage) { + assert(hashAggPlan.find { + case WholeStageCodegenExec(_: HashAggregateExec) => true + case _ => false + }.isDefined) + } else { + assert(hashAggPlan.isInstanceOf[HashAggregateExec]) + } + hashAggDF.collect() + + // test case for ObjectHashAggregate and SortAggregate + val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y")) + val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan + if (useObjectHashAgg) { + assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec]) + } else { + assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec]) + } + objHashAggOrSortAggDF.collect() + } + } + } + + test("SPARK-19471: AggregationIterator does not initialize the generated result projection" + + " before using it") { + Seq( + monotonically_increasing_id(), spark_partition_id(), + rand(Random.nextLong()), randn(Random.nextLong()) + ).foreach(assertNoExceptions) + } + + test("SPARK-21580 ints in aggregation expressions are taken as group-by ordinal.") { + checkAnswer( + testData2.groupBy(lit(3), lit(4)).agg(lit(6), lit(7), sum("b")), + Seq(Row(3, 4, 6, 7, 9))) + checkAnswer( + testData2.groupBy(lit(3), lit(4)).agg(lit(6), 'b, sum("b")), + Seq(Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6))) + + checkAnswer( + spark.sql("SELECT 3, 4, SUM(b) FROM testData2 GROUP BY 1, 2"), + Seq(Row(3, 4, 9))) + checkAnswer( + spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"), + Seq(Row(3, 4, 9))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0e9a2c6cf7de..50e475984f45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -422,7 +422,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { v } withSQLConf( - (SQLConf.WHOLESTAGE_FALLBACK.key, codegenFallback.toString), + (SQLConf.CODEGEN_FALLBACK.key, codegenFallback.toString), (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) { val df = spark.range(0, 4, 1, 4).withColumn("c", c) val rows = df.collect() @@ -448,6 +448,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { rand(Random.nextLong()), randn(Random.nextLong()) ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } + + test("SPARK-21281 use string types by default if array and map have no argument") { + val ds = spark.range(1) + var expectedSchema = new StructType() + 
.add("x", ArrayType(StringType, containsNull = false), nullable = false) + assert(ds.select(array().as("x")).schema == expectedSchema) + expectedSchema = new StructType() + .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) + assert(ds.select(map().as("x")).schema == expectedSchema) + } + + test("SPARK-21281 fails if functions have no argument") { + val df = Seq(1).toDF("a") + + val funcsMustHaveAtLeastOneArg = + ("coalesce", (df: DataFrame) => df.select(coalesce())) :: + ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: + ("named_struct", (df: DataFrame) => df.select(struct())) :: + ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) :: + ("hash", (df: DataFrame) => df.select(hash())) :: + ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil + funcsMustHaveAtLeastOneArg.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least one argument")) + } + + val funcsMustHaveAtLeastTwoArgs = + ("greatest", (df: DataFrame) => df.select(greatest())) :: + ("greatest", (df: DataFrame) => df.selectExpr("greatest()")) :: + ("least", (df: DataFrame) => df.select(least())) :: + ("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil + funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least two arguments")) + } + } } object DataFrameFunctionsSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala new file mode 100644 index 000000000000..0dd5bdcba2e4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.test.SharedSQLContext + +class DataFrameHintSuite extends AnalysisTest with SharedSQLContext { + import testImplicits._ + lazy val df = spark.range(10) + + private def check(df: Dataset[_], expected: LogicalPlan) = { + comparePlans( + df.queryExecution.logical, + expected + ) + } + + test("various hint parameters") { + check( + df.hint("hint1"), + UnresolvedHint("hint1", Seq(), + df.logicalPlan + ) + ) + + check( + df.hint("hint1", 1, "a"), + UnresolvedHint("hint1", Seq(1, "a"), df.logicalPlan) + ) + + check( + df.hint("hint1", 1, $"a"), + UnresolvedHint("hint1", Seq(1, $"a"), + df.logicalPlan + ) + ) + + check( + df.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")), + UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")), + df.logicalPlan + ) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 63094d1b6122..25e1d93ff092 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
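// Illustrative sketch (not part of the patch): how Dataset.hint, as exercised in
// DataFrameHintSuite above, surfaces in the unresolved logical plan. The local
// SparkSession setup below is an assumption made only so the snippet runs on its
// own (e.g. pasted into spark-shell); the hint name "hint1" is arbitrary.
import org.apache.spark.sql.SparkSession

object HintPlanSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("hint-sketch").getOrCreate()
    // Any hint name and parameters are accepted here; they are recorded as an
    // UnresolvedHint node and interpreted (or discarded) later by the analyzer.
    val hinted = spark.range(10).hint("hint1", 1, "a")
    // Prints a plan whose root is an 'UnresolvedHint carrying the name and parameters.
    println(hinted.queryExecution.logical.treeString)
    spark.stop()
  }
}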
+ */ package org.apache.spark.sql diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 541ffb58e727..aef0d7f3e425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -151,7 +151,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } - test("broadcast join hint") { + test("broadcast join hint using broadcast function") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") @@ -174,6 +174,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { } } + test("broadcast join hint using Dataset.hint") { + // make sure a giant join is not broadcastable + val plan1 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong), "id") + .queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) + + // now with a hint it should be broadcasted + val plan2 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong).hint("broadcast"), "id") + .queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1) + } + test("join - outer join conversion") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") @@ -248,4 +264,14 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val ab = a.join(b, Seq("a"), "fullouter") checkAnswer(ab.join(c, "a"), Row(3, null, 4, 1) :: Nil) } + + test("SPARK-17685: WholeStageCodegenExec throws IndexOutOfBoundsException") { + val df = Seq((1, 1, "1"), (2, 2, "3")).toDF("int", "int2", "str") + val df2 = Seq((1, 1, "1"), (2, 3, "5")).toDF("int", "int2", "str") + val limit = 1310721 + val innerJoin = df.limit(limit).join(df2.limit(limit), Seq("int", "int2"), "inner") + .agg(count($"int")) + checkAnswer(innerJoin, Row(1) :: Nil) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index aa237d0619ac..e6983b6be555 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext - class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -104,105 +104,131 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { test("fill") { val input = createDF() - val fillNumeric = input.na.fill(50.6) - checkAnswer( - fillNumeric, - Row("Bob", 16, 176.5) :: - Row("Alice", 50, 164.3) :: - Row("David", 60, 50.6) :: - Row("Nina", 25, 50.6) :: - Row("Amy", 50, 50.6) :: - Row(null, 50, 50.6) :: Nil) - - // Make sure the columns are properly named. 
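// Illustrative sketch (assumption-laden, not part of the patch): basic
// DataFrameNaFunctions.fill usage matching the cases exercised in this suite,
// including the Boolean overload and per-column subsets. The SparkSession and
// the sample data are invented for the example.
import org.apache.spark.sql.SparkSession

object NaFillSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("na-fill-sketch").getOrCreate()
    import spark.implicits._

    val people = Seq(
      ("Bob", Some(16.0), Some(false)),
      ("Alice", None, None),
      (null: String, Some(25.0), Some(true))
    ).toDF("name", "age", "spy")

    // Fill nulls only in the listed columns: 0.0 for "age", "unknown" for "name",
    // and true for "spy" (the Boolean overload tested above).
    people.na.fill(0.0, Seq("age")).show()
    people.na.fill("unknown", Seq("name")).show()
    people.na.fill(true, Seq("spy")).show()

    spark.stop()
  }
}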
- assert(fillNumeric.columns.toSeq === input.columns.toSeq) - - // string - checkAnswer( - input.na.fill("unknown").select("name"), - Row("Bob") :: Row("Alice") :: Row("David") :: - Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil) - assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) - - // fill double with subset columns - checkAnswer( - input.na.fill(50.6, "age" :: Nil).select("name", "age"), - Row("Bob", 16) :: - Row("Alice", 50) :: - Row("David", 60) :: - Row("Nina", 25) :: - Row("Amy", 50) :: - Row(null, 50) :: Nil) - - // fill string with subset columns - checkAnswer( - Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), - Row("test", null)) - - checkAnswer( - Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L)) - .toDF("a", "b").na.fill(0), - Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil - ) - - checkAnswer( - Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null), - (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2), - Row(0, 3.14) :: Row(9123146099426677101L, 0.2) :: Row(9123146560113991650L, 1.6) - :: Row(0, 0.2) :: Nil - ) - - checkAnswer( - Seq[(java.lang.Long, java.lang.Float)]((null, 3.14f), (9123146099426677101L, null), - (9123146560113991650L, 1.6f), (null, null)).toDF("a", "b").na.fill(0.2), - Row(0, 3.14f) :: Row(9123146099426677101L, 0.2f) :: Row(9123146560113991650L, 1.6f) - :: Row(0, 0.2f) :: Nil - ) - - checkAnswer( - Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) - .toDF("a", "b").na.fill(2.34), - Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil - ) - - checkAnswer( - Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) - .toDF("a", "b").na.fill(5), - Row(5, 1.23) :: Row(3, 5.0) :: Row(4, 3.45) :: Nil - ) + val boolInput = Seq[(String, java.lang.Boolean)]( + ("Bob", false), + ("Alice", null), + ("Mallory", true), + (null, null) + ).toDF("name", "spy") + + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val fillNumeric = input.na.fill(50.6) + checkAnswer( + fillNumeric, + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, 50.6) :: + Row("Nina", 25, 50.6) :: + Row("Amy", 50, 50.6) :: + Row(null, 50, 50.6) :: Nil) + + // Make sure the columns are properly named. 
+ assert(fillNumeric.columns.toSeq === input.columns.toSeq) + + // string + checkAnswer( + input.na.fill("unknown").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: + Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil) + assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) + + // boolean + checkAnswer( + boolInput.na.fill(true).select("spy"), + Row(false) :: Row(true) :: Row(true) :: Row(true) :: Nil) + assert(boolInput.na.fill(true).columns.toSeq === boolInput.columns.toSeq) + + // fill double with subset columns + checkAnswer( + input.na.fill(50.6, "age" :: Nil).select("name", "age"), + Row("Bob", 16) :: + Row("Alice", 50) :: + Row("David", 60) :: + Row("Nina", 25) :: + Row("Amy", 50) :: + Row(null, 50) :: Nil) + + // fill boolean with subset columns + checkAnswer( + boolInput.na.fill(true, "spy" :: Nil).select("name", "spy"), + Row("Bob", false) :: + Row("Alice", true) :: + Row("Mallory", true) :: + Row(null, true) :: Nil) + + // fill string with subset columns + checkAnswer( + Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), + Row("test", null)) + + checkAnswer( + Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L)) + .toDF("a", "b").na.fill(0), + Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null), + (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14) :: Row(9123146099426677101L, 0.2) :: Row(9123146560113991650L, 1.6) + :: Row(0, 0.2) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Float)]((null, 3.14f), (9123146099426677101L, null), + (9123146560113991650L, 1.6f), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14f) :: Row(9123146099426677101L, 0.2f) :: Row(9123146560113991650L, 1.6f) + :: Row(0, 0.2f) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) + .toDF("a", "b").na.fill(2.34), + Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) + .toDF("a", "b").na.fill(5), + Row(5, 1.23) :: Row(3, 5.0) :: Row(4, 3.45) :: Nil + ) + } } test("fill with map") { - val df = Seq[(String, String, java.lang.Integer, java.lang.Long, + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val df = Seq[(String, String, java.lang.Integer, java.lang.Long, java.lang.Float, java.lang.Double, java.lang.Boolean)]( - (null, null, null, null, null, null, null)) - .toDF("stringFieldA", "stringFieldB", "integerField", "longField", - "floatField", "doubleField", "booleanField") - - val fillMap = Map( - "stringFieldA" -> "test", - "integerField" -> 1, - "longField" -> 2L, - "floatField" -> 3.3f, - "doubleField" -> 4.4d, - "booleanField" -> false) - - val expectedRow = Row("test", null, 1, 2L, 3.3f, 4.4d, false) - - checkAnswer(df.na.fill(fillMap), expectedRow) - checkAnswer(df.na.fill(fillMap.asJava), expectedRow) // Test Java version - - // Ensure replacement values are cast to the column data type. - checkAnswer(df.na.fill(Map( - "integerField" -> 1d, - "longField" -> 2d, - "floatField" -> 3d, - "doubleField" -> 4d)), - Row(null, null, 1, 2L, 3f, 4d, null)) - - // Ensure column types do not change. Columns that have null values replaced - // will no longer be flagged as nullable, so do not compare schemas directly. 
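// Illustrative sketch (not part of the patch): filling with a per-column map and
// replacing a value with null, the two behaviours tested around here. Everything
// other than the na.fill / na.replace / na.drop calls themselves (session, data,
// column names) is made up for the example.
import org.apache.spark.sql.SparkSession

object NaFillReplaceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("na-map-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("Bob", Some(176.5)),
      ("Alice", None),
      (null: String, Some(164.3))
    ).toDF("name", "height")

    // Per-column defaults; replacement values are cast to each column's type.
    df.na.fill(Map("name" -> "unknown", "height" -> 0.0)).show()

    // Map a concrete value to null (the behaviour added for na.replace), then
    // drop the rows that became null in that column.
    df.na.replace("name", Map("Bob" -> null)).na.drop(Seq("name")).show()

    spark.stop()
  }
}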
- assert(df.na.fill(fillMap).schema.fields.map(_.dataType) === df.schema.fields.map(_.dataType)) + (null, null, null, null, null, null, null)) + .toDF("stringFieldA", "stringFieldB", "integerField", "longField", + "floatField", "doubleField", "booleanField") + + val fillMap = Map( + "stringFieldA" -> "test", + "integerField" -> 1, + "longField" -> 2L, + "floatField" -> 3.3f, + "doubleField" -> 4.4d, + "booleanField" -> false) + + val expectedRow = Row("test", null, 1, 2L, 3.3f, 4.4d, false) + + + checkAnswer(df.na.fill(fillMap), expectedRow) + checkAnswer(df.na.fill(fillMap.asJava), expectedRow) // Test Java version + + // Ensure replacement values are cast to the column data type. + checkAnswer(df.na.fill(Map( + "integerField" -> 1d, + "longField" -> 2d, + "floatField" -> 3d, + "doubleField" -> 4d)), + Row(null, null, 1, 2L, 3f, 4d, null)) + + // Ensure column types do not change. Columns that have null values replaced + // will no longer be flagged as nullable, so do not compare schemas directly. + assert(df.na.fill(fillMap).schema.fields.map(_.dataType) === df.schema.fields.map(_.dataType)) + } } test("replace") { @@ -236,4 +262,47 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) } + + test("replace with null") { + val input = Seq[(String, java.lang.Double, java.lang.Boolean)]( + ("Bob", 176.5, true), + ("Alice", 164.3, false), + ("David", null, true) + ).toDF("name", "height", "married") + + // Replace String with String and null + checkAnswer( + input.na.replace("name", Map( + "Bob" -> "Bravo", + "Alice" -> null + )), + Row("Bravo", 176.5, true) :: + Row(null, 164.3, false) :: + Row("David", null, true) :: Nil) + + // Replace Double with null + checkAnswer( + input.na.replace("height", Map[Any, Any]( + 164.3 -> null + )), + Row("Bob", 176.5, true) :: + Row("Alice", null, false) :: + Row("David", null, true) :: Nil) + + // Replace Boolean with null + checkAnswer( + input.na.replace("*", Map[Any, Any]( + false -> null + )), + Row("Bob", 176.5, true) :: + Row("Alice", 164.3, null) :: + Row("David", null, true) :: Nil) + + // Replace String with null and then drop rows containing null + checkAnswer( + input.na.replace("name", Map( + "Bob" -> null + )).na.drop("name" :: Nil).select("name"), + Row("Alice") :: Row("David") :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 7b495656b93d..45afbd29d190 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -191,6 +191,17 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall checkAnswer(sql("SELECT * FROM range(3)"), Row(0) :: Row(1) :: Row(2) :: Nil) } } + + test("SPARK-21041 SparkSession.range()'s behavior is inconsistent with SparkContext.range()") { + val start = java.lang.Long.MAX_VALUE - 3 + val end = java.lang.Long.MIN_VALUE + 2 + Seq("false", "true").foreach { value => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value) { + assert(spark.range(start, end, 1).collect.length == 0) + assert(spark.range(start, start, 1).collect.length == 0) + } + } + } } object DataFrameRangeSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 
dd118f88e3bb..247c30e2ee65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.Matchers._ import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -229,11 +230,9 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val resNaN1 = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon) assert(resNaN1.count(_.isNaN) === 0) - assert(resNaN1.count(_ == null) === 0) val resNaN2 = dfNaN.stat.approxQuantile("input2", Array(q1, q2), epsilon) assert(resNaN2.count(_.isNaN) === 0) - assert(resNaN2.count(_ == null) === 0) val resNaN3 = dfNaN.stat.approxQuantile("input3", Array(q1, q2), epsilon) assert(resNaN3.isEmpty) @@ -241,7 +240,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val resNaNAll = dfNaN.stat.approxQuantile(Array("input1", "input2", "input3"), Array(q1, q2), epsilon) assert(resNaNAll.flatten.count(_.isNaN) === 0) - assert(resNaNAll.flatten.count(_ == null) === 0) assert(resNaN1(0) === resNaNAll(0)(0)) assert(resNaN1(1) === resNaNAll(0)(1)) @@ -263,52 +261,56 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("crosstab") { - val rng = new Random() - val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) - val df = data.toDF("a", "b") - val crosstab = df.stat.crosstab("a", "b") - val columnNames = crosstab.schema.fieldNames - assert(columnNames(0) === "a_b") - // reduce by key - val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) - val rows = crosstab.collect() - rows.foreach { row => - val i = row.getString(0).toInt - for (col <- 1 until columnNames.length) { - val j = columnNames(col).toInt - assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val rng = new Random() + val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) + val df = data.toDF("a", "b") + val crosstab = df.stat.crosstab("a", "b") + val columnNames = crosstab.schema.fieldNames + assert(columnNames(0) === "a_b") + // reduce by key + val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) + val rows = crosstab.collect() + rows.foreach { row => + val i = row.getString(0).toInt + for (col <- 1 until columnNames.length) { + val j = columnNames(col).toInt + assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + } } } } test("special crosstab elements (., '', null, ``)") { - val data = Seq( - ("a", Double.NaN, "ho"), - (null, 2.0, "ho"), - ("a.b", Double.NegativeInfinity, ""), - ("b", Double.PositiveInfinity, "`ha`"), - ("a", 1.0, null) - ) - val df = data.toDF("1", "2", "3") - val ct1 = df.stat.crosstab("1", "2") - // column fields should be 1 + distinct elements of second column - assert(ct1.schema.fields.length === 6) - assert(ct1.collect().length === 4) - val ct2 = df.stat.crosstab("1", "3") - assert(ct2.schema.fields.length === 5) - assert(ct2.schema.fieldNames.contains("ha")) - assert(ct2.collect().length === 4) - val ct3 = df.stat.crosstab("3", "2") - assert(ct3.schema.fields.length === 6) - assert(ct3.schema.fieldNames.contains("NaN")) - assert(ct3.schema.fieldNames.contains("Infinity")) - 
assert(ct3.schema.fieldNames.contains("-Infinity")) - assert(ct3.collect().length === 4) - val ct4 = df.stat.crosstab("3", "1") - assert(ct4.schema.fields.length === 5) - assert(ct4.schema.fieldNames.contains("null")) - assert(ct4.schema.fieldNames.contains("a.b")) - assert(ct4.collect().length === 4) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val data = Seq( + ("a", Double.NaN, "ho"), + (null, 2.0, "ho"), + ("a.b", Double.NegativeInfinity, ""), + ("b", Double.PositiveInfinity, "`ha`"), + ("a", 1.0, null) + ) + val df = data.toDF("1", "2", "3") + val ct1 = df.stat.crosstab("1", "2") + // column fields should be 1 + distinct elements of second column + assert(ct1.schema.fields.length === 6) + assert(ct1.collect().length === 4) + val ct2 = df.stat.crosstab("1", "3") + assert(ct2.schema.fields.length === 5) + assert(ct2.schema.fieldNames.contains("ha")) + assert(ct2.collect().length === 4) + val ct3 = df.stat.crosstab("3", "2") + assert(ct3.schema.fields.length === 6) + assert(ct3.schema.fieldNames.contains("NaN")) + assert(ct3.schema.fieldNames.contains("Infinity")) + assert(ct3.schema.fieldNames.contains("-Infinity")) + assert(ct3.collect().length === 4) + val ct4 = df.stat.crosstab("3", "1") + assert(ct4.schema.fields.length === 5) + assert(ct4.schema.fieldNames.contains("null")) + assert(ct4.schema.fieldNames.contains("a.b")) + assert(ct4.collect().length === 4) + } } test("Frequent Items") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ef0de6f6f4ff..0e2f2e5a193e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,11 +28,10 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution} import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} @@ -112,6 +111,93 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } + test("union by name") { + var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") + val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") + val unionDf = df1.unionByName(df2.unionByName(df3)) + checkAnswer(unionDf, + Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil + ) + + // Check if adjacent unions are combined into a single one + assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) + + // Check failure cases + df1 = Seq((1, 2)).toDF("a", "c") + df2 = Seq((3, 4, 5)).toDF("a", "b", "c") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains( + "Union can only be performed on tables with the same number of columns, " + + "but the first table has 2 columns 
and the second table has 3 columns")) + + df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + df2 = Seq((4, 5, 6)).toDF("a", "c", "d") + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) + } + + test("union by name - type coercion") { + var df1 = Seq((1, "a")).toDF("c0", "c1") + var df2 = Seq((3, 1L)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) + + df1 = Seq((1, 1.0)).toDF("c0", "c1") + df2 = Seq((8L, 3.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) + + df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") + df2 = Seq(("a", 4.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) + + df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") + df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") + val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") + checkAnswer(df1.unionByName(df2.unionByName(df3)), + Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil + ) + } + + test("union by name - check case sensitivity") { + def checkCaseSensitiveTest(): Unit = { + val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") + val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") + checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg2 = intercept[AnalysisException] { + checkCaseSensitiveTest() + }.getMessage + assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkCaseSensitiveTest() + } + } + + test("union by name - check name duplication") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var df1 = Seq((1, 1)).toDF(c0, c1) + var df2 = Seq((1, 1)).toDF("c0", "c1") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) + df1 = Seq((1, 1)).toDF("c0", "c1") + df2 = Seq((1, 1)).toDF(c0, c1) + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) + } + } + } + test("empty data frame") { assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(spark.emptyDataFrame.count() === 0) @@ -663,13 +749,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } - test("describe") { - val describeTestData = Seq( - ("Bob", 16, 176), - ("Alice", 32, 164), - ("David", 60, 192), - ("Amy", 24, 180)).toDF("name", "age", "height") + private lazy val person2: DataFrame = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + test("describe") { val describeResult = Seq( Row("count", "4", "4", "4"), Row("mean", null, "33.0", "178.0"), @@ -686,32 +772,99 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) - val describeTwoCols = describeTestData.describe("name", "age", "height") - assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height")) - checkAnswer(describeTwoCols, describeResult) - // All aggregate value should have been cast to string - 
describeTwoCols.collect().foreach { row => - assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) - assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass) - } - - val describeAllCols = describeTestData.describe() + val describeAllCols = person2.describe() assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) checkAnswer(describeAllCols, describeResult) + // All aggregate value should have been cast to string + describeAllCols.collect().foreach { row => + row.toSeq.foreach { value => + if (value != null) { + assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + } + } + } - val describeOneCol = describeTestData.describe("age") + val describeOneCol = person2.describe("age") assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) - val describeNoCol = describeTestData.select("name").describe() - assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name")) - checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} ) + val describeNoCol = person2.select().describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s)} ) - val emptyDescription = describeTestData.limit(0).describe() + val emptyDescription = person2.limit(0).describe() assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) checkAnswer(emptyDescription, emptyDescribeResult) } + test("summary") { + val summaryResult = Seq( + Row("count", "4", "4", "4"), + Row("mean", null, "33.0", "178.0"), + Row("stddev", null, "19.148542155126762", "11.547005383792516"), + Row("min", "Alice", "16", "164"), + Row("25%", null, "24", "176"), + Row("50%", null, "24", "176"), + Row("75%", null, "32", "180"), + Row("max", "David", "60", "192")) + + val emptySummaryResult = Seq( + Row("count", "0", "0", "0"), + Row("mean", null, null, null), + Row("stddev", null, null, null), + Row("min", null, null, null), + Row("25%", null, null, null), + Row("50%", null, null, null), + Row("75%", null, null, null), + Row("max", null, null, null)) + + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) + + val summaryAllCols = person2.summary() + + assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height")) + checkAnswer(summaryAllCols, summaryResult) + // All aggregate value should have been cast to string + summaryAllCols.collect().foreach { row => + row.toSeq.foreach { value => + if (value != null) { + assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + } + } + } + + val summaryOneCol = person2.select("age").summary() + assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age")) + checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d)} ) + + val summaryNoCol = person2.select().summary() + assert(getSchemaAsSeq(summaryNoCol) === Seq("summary")) + checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s)} ) + + val emptyDescription = person2.limit(0).summary() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) + checkAnswer(emptyDescription, emptySummaryResult) + } + + test("summary advanced") { + val stats = Array("count", "50.01%", "max", "mean", "min", "25%") + val orderMatters = person2.summary(stats: _*) + 
assert(orderMatters.collect().map(_.getString(0)) === stats) + + val onlyPercentiles = person2.summary("0.1%", "99.9%") + assert(onlyPercentiles.count() === 2) + + val fooE = intercept[IllegalArgumentException] { + person2.summary("foo") + } + assert(fooE.getMessage === "foo is not a recognised statistic") + + val parseE = intercept[IllegalArgumentException] { + person2.summary("foo%") + } + assert(parseE.getMessage === "Unable to parse foo% as a percentile") + } + test("apply on query results (SPARK-5462)") { val df = testData.sparkSession.sql("select key from testData") checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) @@ -1014,7 +1167,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-6899: type should match when using codegen") { - checkAnswer(decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) + checkAnswer(decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2))) } test("SPARK-7133: Implement struct, array, and map field accessor") { @@ -1026,28 +1179,31 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = spark.read.json(Seq("""{"a.b": {"c": {"d..e": {"f": 1}}}}""").toDS()) - checkAnswer( - df.select(df("`a.b`.c.`d..e`.`f`")), - Row(1) - ) - - val df2 = spark.read.json(Seq("""{"a b": {"c": {"d e": {"f": 1}}}}""").toDS()) - checkAnswer( - df2.select(df2("`a b`.c.d e.f")), - Row(1) - ) - - def checkError(testFun: => Unit): Unit = { - val e = intercept[org.apache.spark.sql.AnalysisException] { - testFun + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val df = spark.read.json(Seq("""{"a.b": {"c": {"d..e": {"f": 1}}}}""").toDS()) + checkAnswer( + df.select(df("`a.b`.c.`d..e`.`f`")), + Row(1) + ) + + val df2 = spark.read.json(Seq("""{"a b": {"c": {"d e": {"f": 1}}}}""").toDS()) + checkAnswer( + df2.select(df2("`a b`.c.d e.f")), + Row(1) + ) + + def checkError(testFun: => Unit): Unit = { + val e = intercept[org.apache.spark.sql.AnalysisException] { + testFun + } + assert(e.getMessage.contains("syntax error in attribute name:")) } - assert(e.getMessage.contains("syntax error in attribute name:")) + + checkError(df("`abc.`c`")) + checkError(df("`abc`..d")) + checkError(df("`a`.b.")) + checkError(df("`a.b`.c.`d")) } - checkError(df("`abc.`c`")) - checkError(df("`abc`..d")) - checkError(df("`a`.b.")) - checkError(df("`a.b`.c.`d")) } test("SPARK-7324 dropDuplicates") { @@ -1123,7 +1279,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") .write.format("parquet").save("temp") } - assert(e.getMessage.contains("Duplicate column(s)")) + assert(e.getMessage.contains("Found duplicate column(s) when inserting into")) assert(e.getMessage.contains("column1")) assert(!e.getMessage.contains("column2")) @@ -1133,7 +1289,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .toDF("column1", "column2", "column3", "column1", "column3") .write.format("json").save("temp") } - assert(f.getMessage.contains("Duplicate column(s)")) + assert(f.getMessage.contains("Found duplicate column(s) when inserting into")) assert(f.getMessage.contains("column1")) assert(f.getMessage.contains("column3")) assert(!f.getMessage.contains("column2")) @@ -1177,7 +1333,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert 
into an OneRowRelation - Dataset.ofRows(spark, OneRowRelation).createOrReplaceTempView("one_row") + Dataset.ofRows(spark, OneRowRelation()).createOrReplaceTempView("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } @@ -1373,7 +1529,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { fail("Should not have back to back Aggregates") } atFirstAgg = true - case e: ShuffleExchange => atFirstAgg = false + case e: ShuffleExchangeExec => atFirstAgg = false case _ => } } @@ -1554,19 +1710,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val plan = join.queryExecution.executedPlan checkAnswer(join, df) assert( - join.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + join.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size === 1) assert( join.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 1) val broadcasted = broadcast(join) val join2 = join.join(broadcasted, "id").join(broadcasted, "id") checkAnswer(join2, df) assert( - join2.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + join2.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) assert( join2.queryExecution.executedPlan .collect { case e: BroadcastExchangeExec => true }.size === 1) assert( - join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 4) + join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size == 4) } } @@ -1775,11 +1931,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-17957: outer join + na.fill") { - val df1 = Seq((1, 2), (2, 3)).toDF("a", "b") - val df2 = Seq((2, 5), (3, 4)).toDF("a", "c") - val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0) - val df3 = Seq((3, 1)).toDF("a", "d") - checkAnswer(joinedDf.join(df3, "a"), Row(3, 0, 4, 1)) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val df1 = Seq((1, 2), (2, 3)).toDF("a", "b") + val df2 = Seq((2, 5), (3, 4)).toDF("a", "c") + val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0) + val df3 = Seq((3, 1)).toDF("a", "d") + checkAnswer(joinedDf.join(df3, "a"), Row(3, 0, 4, 1)) + } } test("SPARK-17123: Performing set operations that combine non-scala native types") { @@ -1813,7 +1971,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") { val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)") - checkAnswer(df, Row(BigDecimal(0.0)) :: Nil) + checkAnswer(df, Row(BigDecimal(0)) :: Nil) } test("SPARK-19893: cannot run set operations with map type") { @@ -1844,4 +2002,41 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .filter($"x1".isNotNull || !$"y".isin("a!")) .count } + + testQuietly("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") { + val N = 400 + val rows = Seq(Row.fromSeq(Seq.fill(N)("string"))) + val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType))) + val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema) + + val filter = (0 until N) + .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string")) + + withSQLConf(SQLConf.CODEGEN_FALLBACK.key -> "true") { + df.filter(filter).count() + } + + withSQLConf(SQLConf.CODEGEN_FALLBACK.key -> "false") { + val 
e = intercept[SparkException] { + df.filter(filter).count() + }.getMessage + assert(e.contains("grows beyond 64 KB")) + } + } + + test("SPARK-20897: cached self-join should not fail") { + // force to plan sort merge join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df = Seq(1 -> "a").toDF("i", "j") + val df1 = df.as("t1") + val df2 = df.as("t2") + assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1) + } + } + + test("order-by ordinal.") { + checkAnswer( + testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), + Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 22d5c47a6fb5..6fe356877c26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql -import java.util.TimeZone - import org.scalatest.BeforeAndAfterEach +import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StringType @@ -29,11 +28,27 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B import testImplicits._ + test("simple tumbling window with record at window start") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + Seq( + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1) + ) + ) + } + test("tumbling window groupBy statement") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), ("2016-03-27 19:39:56", 2, "a"), ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + checkAnswer( df.groupBy(window($"time", "10 seconds")) .agg(count("*").as("counts")) @@ -59,14 +74,18 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("tumbling window with multi-column projection") { val df = Seq( - ("2016-03-27 19:39:34", 1, "a"), - ("2016-03-27 19:39:56", 2, "a"), - ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(window($"time", "10 seconds"), $"value") + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.isEmpty, "Tumbling windows shouldn't require expand") checkAnswer( - df.select(window($"time", "10 seconds"), $"value") - .orderBy($"window.start".asc) - .select($"window.start".cast("string"), $"window.end".cast("string"), $"value"), + df, Seq( Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), @@ -104,13 +123,17 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("sliding window projection") { val df = Seq( - ("2016-03-27 19:39:34", 1, "a"), - ("2016-03-27 19:39:56", 2, "a"), - ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + ("2016-03-27 19:39:34", 1, "a"), + 
("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.nonEmpty, "Sliding windows require expand") checkAnswer( - df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") - .orderBy($"window.start".asc, $"value".desc).select("value"), + df, // 2016-03-27 19:39:27 UTC -> 4 bins // 2016-03-27 19:39:34 UTC -> 3 bins // 2016-03-27 19:39:56 UTC -> 3 bins diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 1255c4910471..ea725af8d1ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql +import java.sql.{Date, Timestamp} + import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval /** * Window function testing for DataFrame API. @@ -150,6 +154,96 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(2.0d), Row(2.0d))) } + test("row between should accept integer values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), + (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), + Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + + val e = intercept[AnalysisException]( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) + assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) + } + + test("range between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), + (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), + Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), + Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) + ) + + def dt(date: String): Date = Date.valueOf(date) + + val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), + (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) + .toDF("key", "value") + checkAnswer( + df2.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)))), + Seq(Row(dt("2017-08-01"), 3), 
Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), + Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) + ) + } + + test("range between should accept double values as boundary") { + val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), + (3.3D, "2"), (2.02D, "1"), (100.001D, "2")) + .toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key") + .rangeBetween(currentRow, lit(2.5D)))), + Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) + ) + } + + test("range between should accept interval values as boundary") { + def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) + + val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), + (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) + .toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key") + .rangeBetween(currentRow, + lit(CalendarInterval.fromString("interval 23 days 4 hours"))))), + Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), + Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) + ) + } + test("aggregation and rows between with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") df.createOrReplaceTempView("window_table") @@ -423,4 +517,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { df.select(selectList: _*).where($"value" < 2), Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) } + + test("SPARK-21258: complex object in combination with spilling") { + // Make sure we trigger the spilling path. + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { + val sampleSchema = new StructType(). + add("f0", StringType). + add("f1", LongType). + add("f2", ArrayType(new StructType(). + add("f20", StringType))). + add("f3", ArrayType(new StructType(). 
+ add("f30", StringType))) + + val w0 = Window.partitionBy("f0").orderBy("f1") + val w1 = w0.rowsBetween(Long.MinValue, Long.MaxValue) + + val c0 = first(struct($"f2", $"f3")).over(w0) as "c0" + val c1 = last(struct($"f2", $"f3")).over(w1) as "c1" + + val input = + """{"f1":1497820153720,"f2":[{"f20":"x","f21":0}],"f3":[{"f30":"x","f31":0}]} + |{"f1":1497802179638} + |{"f1":1497802189347} + |{"f1":1497802189593} + |{"f1":1497802189597} + |{"f1":1497802189599} + |{"f1":1497802192103} + |{"f1":1497802193414} + |{"f1":1497802193577} + |{"f1":1497802193709} + |{"f1":1497802202883} + |{"f1":1497802203006} + |{"f1":1497802203743} + |{"f1":1497802203834} + |{"f1":1497802203887} + |{"f1":1497802203893} + |{"f1":1497802203976} + |{"f1":1497820168098} + |""".stripMargin.split("\n").toSeq + + import testImplicits._ + + spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 541565344f75..edcdd77908d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql +import scala.collection.immutable.{HashSet => HSet} import scala.collection.immutable.Queue +import scala.collection.mutable.{LinkedHashMap => LHMap} import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.test.SharedSQLContext @@ -30,8 +32,17 @@ case class ListClass(l: List[Int]) case class QueueClass(q: Queue[Int]) +case class MapClass(m: Map[Int, Int]) + +case class LHMapClass(m: LHMap[Int, Int]) + case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) +case class ComplexMapClass(map: MapClass, lhmap: LHMapClass) + +case class InnerData(name: String, value: Int) +case class NestedData(id: Int, param: Map[String, InnerData]) + package object packageobject { case class PackageClass(value: Int) } @@ -140,7 +151,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("foreachPartition") { val ds = Seq(1, 2, 3).toDS() val acc = sparkContext.longAccumulator - ds.foreachPartition(_.foreach(acc.add(_))) + ds.foreachPartition((it: Iterator[Int]) => it.foreach(acc.add(_))) assert(acc.value == 6) } @@ -258,9 +269,128 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))) } + test("arbitrary maps") { + checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2)) + checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong)) + checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble)) + checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat)) + checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte)) + checkDataset(Seq(Map(1.toShort -> 2.toShort)).toDS(), Map(1.toShort -> 2.toShort)) + checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false)) + checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2")) + checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2))) + checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong)) + + checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2)) + checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 
2.toLong)) + checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble)) + checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat)) + checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte)) + checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort)) + checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false)) + checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2")) + checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2))) + checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong)) + } + + ignore("SPARK-19104: map and product combinations") { + // Case classes + checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2))) + checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(), + Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(), + LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + + checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2))) + checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + Map(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + LHMap(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + + val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4))) + checkDataset(Seq(complex).toDS(), complex) + checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex)) + checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5)) + checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex)) + checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex)) + checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5)) + checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex)) + + // Tuples + checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 4)).toDS(), Map(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(), + LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))) + + // Complex + checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 
4)))).toDS(), + LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))) + } + + test("arbitrary sets") { + checkDataset(Seq(Set(1, 2, 3, 4)).toDS(), Set(1, 2, 3, 4)) + checkDataset(Seq(Set(1.toLong, 2.toLong)).toDS(), Set(1.toLong, 2.toLong)) + checkDataset(Seq(Set(1.toDouble, 2.toDouble)).toDS(), Set(1.toDouble, 2.toDouble)) + checkDataset(Seq(Set(1.toFloat, 2.toFloat)).toDS(), Set(1.toFloat, 2.toFloat)) + checkDataset(Seq(Set(1.toByte, 2.toByte)).toDS(), Set(1.toByte, 2.toByte)) + checkDataset(Seq(Set(1.toShort, 2.toShort)).toDS(), Set(1.toShort, 2.toShort)) + checkDataset(Seq(Set(true, false)).toDS(), Set(true, false)) + checkDataset(Seq(Set("test1", "test2")).toDS(), Set("test1", "test2")) + checkDataset(Seq(Set(Tuple1(1), Tuple1(2))).toDS(), Set(Tuple1(1), Tuple1(2))) + + checkDataset(Seq(HSet(1, 2)).toDS(), HSet(1, 2)) + checkDataset(Seq(HSet(1.toLong, 2.toLong)).toDS(), HSet(1.toLong, 2.toLong)) + checkDataset(Seq(HSet(1.toDouble, 2.toDouble)).toDS(), HSet(1.toDouble, 2.toDouble)) + checkDataset(Seq(HSet(1.toFloat, 2.toFloat)).toDS(), HSet(1.toFloat, 2.toFloat)) + checkDataset(Seq(HSet(1.toByte, 2.toByte)).toDS(), HSet(1.toByte, 2.toByte)) + checkDataset(Seq(HSet(1.toShort, 2.toShort)).toDS(), HSet(1.toShort, 2.toShort)) + checkDataset(Seq(HSet(true, false)).toDS(), HSet(true, false)) + checkDataset(Seq(HSet("test1", "test2")).toDS(), HSet("test1", "test2")) + checkDataset(Seq(HSet(Tuple1(1), Tuple1(2))).toDS(), HSet(Tuple1(1), Tuple1(2))) + + checkDataset(Seq(Seq(Some(1), None), Seq(Some(2))).toDF("c").as[Set[Integer]], + Seq(Set[Integer](1, null), Set[Integer](2)): _*) + } + + test("nested sequences") { + checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1))) + checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) + } + + test("nested maps") { + checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3))) + checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3)) + } + + test("nested set") { + checkDataset(Seq(Set(HSet(1, 2), HSet(3, 4))).toDS(), Set(HSet(1, 2), HSet(3, 4))) + checkDataset(Seq(HSet(Set(1, 2), Set(3, 4))).toDS(), HSet(Set(1, 2), Set(3, 4))) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) } + test("SPARK-19104: Lambda variables in ExternalMapToCatalyst should be global") { + val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100)))) + val ds = spark.createDataset(data) + checkDataset(ds, data: _*) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5b5cd28ad0c9..dace6825ee40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -21,17 +21,29 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} +import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide -import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.streaming.MemoryStream import 
org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2) case class TestDataPoint2(x: Int, s: String) +object TestForTypeAlias { + type TwoInt = (Int, Int) + type ThreeInt = (TwoInt, Int) + type SeqOfTwoInt = Seq[TwoInt] + + def tupleTypeAlias: TwoInt = (1, 1) + def nestedTupleTypeAlias: ThreeInt = ((1, 1), 2) + def seqOfTupleTypeAlias: SeqOfTwoInt = Seq((1, 1), (2, 2)) +} + class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -243,6 +255,85 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3))) } + test("REGEX column specification") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + var e = intercept[AnalysisException] { + ds.select(expr("`(_1)?+.+`").as[Int]) + }.getMessage + assert(e.contains("cannot resolve '`(_1)?+.+`'")) + + e = intercept[AnalysisException] { + ds.select(expr("`(_1|_2)`").as[Int]) + }.getMessage + assert(e.contains("cannot resolve '`(_1|_2)`'")) + + e = intercept[AnalysisException] { + ds.select(ds("`(_1)?+.+`")) + }.getMessage + assert(e.contains("Cannot resolve column name \"`(_1)?+.+`\"")) + + e = intercept[AnalysisException] { + ds.select(ds("`(_1|_2)`")) + }.getMessage + assert(e.contains("Cannot resolve column name \"`(_1|_2)`\"")) + } + + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "true") { + checkDataset( + ds.select(ds.col("_2")).as[Int], + 1, 2, 3) + + checkDataset( + ds.select(ds.colRegex("`(_1)?+.+`")).as[Int], + 1, 2, 3) + + checkDataset( + ds.select(ds("`(_1|_2)`")) + .select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]), + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + + checkDataset( + ds.alias("g") + .select(ds("g.`(_1|_2)`")) + .select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]), + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + + checkDataset( + ds.select(ds("`(_1)?+.+`")) + .select(expr("_2").as[Int]), + 1, 2, 3) + + checkDataset( + ds.alias("g") + .select(ds("g.`(_1)?+.+`")) + .select(expr("_2").as[Int]), + 1, 2, 3) + + checkDataset( + ds.select(expr("`(_1)?+.+`").as[Int]), + 1, 2, 3) + val m = ds.select(expr("`(_1|_2)`")) + + checkDataset( + ds.select(expr("`(_1|_2)`")) + .select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]), + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + + checkDataset( + ds.alias("g") + .select(expr("g.`(_1)?+.+`").as[Int]), + 1, 2, 3) + + checkDataset( + ds.alias("g") + .select(expr("g.`(_1|_2)`")) + .select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]), + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + } + } + test("filter") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDataset( @@ -273,13 +364,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("foreachPartition") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() val acc = sparkContext.longAccumulator - ds.foreachPartition(_.foreach(v => acc.add(v._2))) + ds.foreachPartition((it: Iterator[(String, Int)]) => it.foreach(v => acc.add(v._2))) assert(acc.value == 6) } test("reduce") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) + assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == (("sum", 6))) } 
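// The extra parentheses around ("sum", 6) in the reduce assertion above are deliberate:
// `pair == ("sum", 6)` is parsed by scalac as a two-argument call to `==`, which the compiler
// then auto-tuples and which lint settings such as -Xlint:adapted-args flag; wrapping the tuple
// keeps it a single argument. A minimal, Spark-independent sketch of the same point
// (the object name here is illustrative only, not part of the patch):
object TupleEqualitySketch {
  def main(args: Array[String]): Unit = {
    val pair: (String, Int) = ("sum", 1 + 2 + 3)
    // Explicit tuple literal: a single argument to ==, so no argument adaptation is needed.
    assert(pair == (("sum", 6)))
  }
}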
test("joinWith, flat schema") { @@ -320,6 +411,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ((("b", 2), ("b", 2)), ("b", 2))) } + test("joinWith join types") { + val ds1 = Seq(1, 2, 3).toDS().as("a") + val ds2 = Seq(1, 2).toDS().as("b") + + val e1 = intercept[AnalysisException] { + ds1.joinWith(ds2, $"a.value" === $"b.value", "left_semi") + }.getMessage + assert(e1.contains("Invalid join type in joinWith: " + LeftSemi.sql)) + + val e2 = intercept[AnalysisException] { + ds1.joinWith(ds2, $"a.value" === $"b.value", "left_anti") + }.getMessage + assert(e2.contains("Invalid join type in joinWith: " + LeftAnti.sql)) + } + test("groupBy function, keys") { val ds = Seq(("a", 1), ("b", 1)).toDS() val grouped = ds.groupByKey(v => (1, v._2)) @@ -456,6 +562,34 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 3, 17, 27, 58, 62) } + test("sample fraction should not be negative with replacement") { + val data = sparkContext.parallelize(1 to 2, 1).toDS() + val errMsg = intercept[IllegalArgumentException] { + data.sample(withReplacement = true, -0.1, 0) + }.getMessage + assert(errMsg.contains("Sampling fraction (-0.1) must be nonnegative with replacement")) + + // Sampling fraction can be greater than 1 with replacement. + checkDataset( + data.sample(withReplacement = true, 1.05, seed = 13), + 1, 2) + } + + test("sample fraction should be on interval [0, 1] without replacement") { + val data = sparkContext.parallelize(1 to 2, 1).toDS() + val errMsg1 = intercept[IllegalArgumentException] { + data.sample(withReplacement = false, -0.1, 0) + }.getMessage() + assert(errMsg1.contains( + "Sampling fraction (-0.1) must be on interval [0, 1] without replacement")) + + val errMsg2 = intercept[IllegalArgumentException] { + data.sample(withReplacement = false, 1.1, 0) + }.getMessage() + assert(errMsg2.contains( + "Sampling fraction (1.1) must be on interval [0, 1] without replacement")) + } + test("SPARK-16686: Dataset.sample with seed results shouldn't depend on downstream usage") { val simpleUdf = udf((n: Int) => { require(n != 1, "simpleUdf shouldn't see id=1!") @@ -676,7 +810,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("SPARK-14000: case class with tuple type field") { checkDataset( Seq(TupleClass((1, "a"))).toDS(), - TupleClass(1, "a") + TupleClass((1, "a")) ) } @@ -1072,7 +1206,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val agg = cp.groupBy('id % 2).agg(count('id)) agg.queryExecution.executedPlan.collectFirst { - case ShuffleExchange(_, _: RDDScanExec, _) => + case ShuffleExchangeExec(_, _: RDDScanExec, _) => case BroadcastExchangeExec(_, _: RDDScanExec) => }.foreach { _ => fail( @@ -1117,7 +1251,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { // instead of Int for avoiding possible overflow. 
val ds = (0 to 10000).map( i => (i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS() - val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes + val sizeInBytes = ds.logicalPlan.stats.sizeInBytes // sizeInBytes is 2404280404, before the fix, it overflows to a negative number assert(sizeInBytes > 0) } @@ -1168,6 +1302,45 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS() checkDataset(ds, WithMapInOption(Some(Map(1 -> 1)))) } + + test("SPARK-20399: do not unescaped regex pattern when ESCAPED_STRING_LITERALS is enabled") { + withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") { + val data = Seq("\u0020\u0021\u0023", "abc") + val df = data.toDF() + val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'") + val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$")) + val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'") + checkAnswer(rlike1, rlike2) + assert(rlike3.count() == 0) + } + } + + test("SPARK-21538: Attribute resolution inconsistency in Dataset API") { + val df = spark.range(3).withColumnRenamed("id", "x") + val expected = Row(0) :: Row(1) :: Row (2) :: Nil + checkAnswer(df.sort("id"), expected) + checkAnswer(df.sort(col("id")), expected) + checkAnswer(df.sort($"id"), expected) + checkAnswer(df.sort('id), expected) + checkAnswer(df.orderBy("id"), expected) + checkAnswer(df.orderBy(col("id")), expected) + checkAnswer(df.orderBy($"id"), expected) + checkAnswer(df.orderBy('id), expected) + } + + test("SPARK-21567: Dataset should work with type alias") { + checkDataset( + Seq(1).toDS().map(_ => ("", TestForTypeAlias.tupleTypeAlias)), + ("", (1, 1))) + + checkDataset( + Seq(1).toDS().map(_ => ("", TestForTypeAlias.nestedTupleTypeAlias)), + ("", ((1, 1), 2))) + + checkDataset( + Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)), + ("", Seq((1, 1), (2, 2)))) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 2acda3f00732..3a8694839bb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -387,7 +387,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("to_date(s)"), Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")), Row(null))) - // Now with format + // now with format checkAnswer( df.select(to_date(col("t"), "yyyy-MM-dd")), Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")), @@ -400,7 +400,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { df.select(to_date(col("s"), "yyyy-MM-dd")), Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")), Row(null))) - // now switch format + // now switch format checkAnswer( df.select(to_date(col("s"), "yyyy-dd-MM")), Seq(Row(null), Row(null), Row(Date.valueOf("2014-12-31")))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index b9871afd59e4..6b98209fd49b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -43,6 +43,10 @@ class GeneratorFunctionSuite extends QueryTest 
with SharedSQLContext { checkAnswer(df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 'c')"), Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil) + // Null values + checkAnswer(df.selectExpr("stack(3, 1, 1.1, null, 2, null, 'b', null, 3.3, 'c')"), + Row(1, 1.1, null) :: Row(2, null, "b") :: Row(null, 3.3, "c") :: Nil) + // Repeat generation at every input row checkAnswer(spark.range(2).selectExpr("stack(2, 1, 2, 3)"), Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil) @@ -297,7 +301,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { } test("outer generator()") { - spark.sessionState.functionRegistry.registerFunction("empty_gen", _ => EmptyGenerator()) + spark.sessionState.functionRegistry + .createOrReplaceTempFunction("empty_gen", _ => EmptyGenerator()) checkAnswer( sql("select * from values 1, 2 lateral view outer empty_gen() a as b"), Row(1, null) :: Row(2, null) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 1a66aa85f5a0..226cc3028b13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,15 +17,19 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.language.existentials -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} +import org.apache.spark.sql.execution.SortExec import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} +import org.apache.spark.sql.types.StructType class JoinSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -33,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() def statisticSizeInByte(df: DataFrame): BigInt = { - df.queryExecution.optimizedPlan.stats(sqlConf).sizeInBytes + df.queryExecution.optimizedPlan.stats.sizeInBytes } test("equi-join is hash-join") { @@ -126,7 +130,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoinExec]) ).foreach(assertJoin) - sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { @@ -141,7 +144,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashJoinExec]) ).foreach(assertJoin) - sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { @@ -198,6 +200,14 @@ class JoinSuite extends QueryTest with SharedSQLContext { Nil) } + test("SPARK-22141: Propagate empty relation before checking Cartesian products") { + Seq("inner", "left", "right", "left_outer", "right_outer", "full_outer").foreach { joinType => + val x = testData2.where($"a" === 2 && !($"a" === 2)).as("x") + val y = testData2.where($"a" === 1 && !($"a" === 1)).as("y") + checkAnswer(x.join(y, Seq.empty, joinType), Nil) + } + } + test("big inner join, 4 matches per row") { val bigData = 
testData.union(testData).union(testData).union(testData) val bigDataX = bigData.as("x") @@ -216,6 +226,9 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(1, null, 2, 2) :: Row(2, 2, 1, null) :: Row(2, 2, 2, 2) :: Nil) + checkAnswer( + testData3.as("x").join(testData3.as("y"), $"x.a" > $"y.a"), + Row(2, 2, 1, null) :: Nil) } withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { val e = intercept[Exception] { @@ -473,7 +486,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { ).foreach(assertJoin) } - sql("UNCACHE TABLE testData") } test("cross join with broadcast") { @@ -562,7 +574,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row("2", 3, 2) :: Nil) } - sql("UNCACHE TABLE testData") } test("left semi join") { @@ -604,6 +615,35 @@ class JoinSuite extends QueryTest with SharedSQLContext { } cartesianQueries.foreach(checkCartesianDetection) + + // Check that left_semi, left_anti, existence joins without conditions do not throw + // an exception if cross joins are disabled + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + checkAnswer( + sql("SELECT * FROM testData3 LEFT SEMI JOIN testData2"), + Row(1, null) :: Row (2, 2) :: Nil) + checkAnswer( + sql("SELECT * FROM testData3 LEFT ANTI JOIN testData2"), + Nil) + checkAnswer( + sql( + """ + |SELECT a FROM testData3 + |WHERE + | EXISTS (SELECT * FROM testData) + |OR + | EXISTS (SELECT * FROM testData2)""".stripMargin), + Row(1) :: Row(2) :: Nil) + checkAnswer( + sql( + """ + |SELECT key FROM testData + |WHERE + | key IN (SELECT a FROM testData2) + |OR + | key IN (SELECT a FROM testData3)""".stripMargin), + Row(1) :: Row(2) :: Row(3) :: Nil) + } } test("test SortMergeJoin (without spill)") { @@ -665,7 +705,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("test SortMergeJoin (with spill)") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", - "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") { + "spark.sql.sortMergeJoinExec.buffer.in.memory.threshold" -> "0", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "1") { assertSpilled(sparkContext, "inner join") { checkAnswer( @@ -738,4 +779,82 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } } + + test("outer broadcast hash join should not throw NPE") { + withTempView("v1", "v2") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { + Seq(2 -> 2).toDF("x", "y").createTempView("v1") + + spark.createDataFrame( + Seq(Row(1, "a")).asJava, + new StructType().add("i", "int", nullable = false).add("j", "string", nullable = false) + ).createTempView("v2") + + checkAnswer( + sql("select x, y, i, j from v1 left join v2 on x = i and y < length(j)"), + Row(2, 2, null, null) + ) + } + } + } + + test("test SortMergeJoin output ordering") { + val joinQueries = Seq( + "SELECT * FROM testData JOIN testData2 ON key = a", + "SELECT * FROM testData t1 JOIN " + + "testData2 t2 ON t1.key = t2.a JOIN testData3 t3 ON t2.a = t3.a", + "SELECT * FROM testData t1 JOIN " + + "testData2 t2 ON t1.key = t2.a JOIN " + + "testData3 t3 ON t2.a = t3.a JOIN " + + "testData t4 ON t1.key = t4.key") + + def assertJoinOrdering(sqlString: String): Unit = { + val df = sql(sqlString) + val physical = df.queryExecution.sparkPlan + val physicalJoins = physical.collect { + case j: SortMergeJoinExec => j + } + val executed = df.queryExecution.executedPlan + val executedJoins = executed.collect { + case j: SortMergeJoinExec => j + } + // This only applies to the above tested queries, in which a child SortMergeJoin 
always + // contains the SortOrder required by its parent SortMergeJoin. Thus, SortExec should never + // appear as parent of SortMergeJoin. + executed.foreach { + case s: SortExec => s.foreach { + case j: SortMergeJoinExec => fail( + s"No extra sort should be added since $j already satisfies the required ordering" + ) + case _ => + } + case _ => + } + val joinPairs = physicalJoins.zip(executedJoins) + val numOfJoins = sqlString.split(" ").count(_.toUpperCase == "JOIN") + assert(joinPairs.size == numOfJoins) + + joinPairs.foreach { + case(join1, join2) => + val leftKeys = join1.leftKeys + val rightKeys = join1.rightKeys + val outputOrderingPhysical = join1.outputOrdering + val outputOrderingExecuted = join2.outputOrdering + + // outputOrdering should always contain join keys + assert( + SortOrder.orderingSatisfies( + outputOrderingPhysical, leftKeys.map(SortOrder(_, Ascending)))) + assert( + SortOrder.orderingSatisfies( + outputOrderingPhysical, rightKeys.map(SortOrder(_, Ascending)))) + // outputOrdering should be consistent between physical plan and executed plan + assert(outputOrderingPhysical == outputOrderingExecuted, + s"Operator $join1 did not have the same output ordering in the physical plan as in " + + s"the executed plan.") + } + } + + joinQueries.foreach(assertJoinOrdering) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 69a500c845a7..00d2acc4a1d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.{from_json, struct, to_json} +import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -156,13 +156,20 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(Row(1, "a"), Row(2, null), Row(null, null)))) } - test("from_json uses DDL strings for defining a schema") { + test("from_json uses DDL strings for defining a schema - java") { val df = Seq("""{"a": 1, "b": "haa"}""").toDS() checkAnswer( df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]())), Row(Row(1, "haa")) :: Nil) } + test("from_json uses DDL strings for defining a schema - scala") { + val df = Seq("""{"a": 1, "b": "haa"}""").toDS() + checkAnswer( + df.select(from_json($"value", "a INT, b STRING", Map[String, String]())), + Row(Row(1, "haa")) :: Nil) + } + test("to_json - struct") { val df = Seq(Tuple1(Tuple1(1))).toDF("a") @@ -173,10 +180,26 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { test("to_json - array") { val df = Seq(Tuple1(Tuple1(1) :: Nil)).toDF("a") + val df2 = Seq(Tuple1(Map("a" -> 1) :: Nil)).toDF("a") checkAnswer( df.select(to_json($"a")), Row("""[{"_1":1}]""") :: Nil) + checkAnswer( + df2.select(to_json($"a")), + Row("""[{"a":1}]""") :: Nil) + } + + test("to_json - map") { + val df1 = Seq(Map("a" -> Tuple1(1))).toDF("a") + val df2 = Seq(Map("a" -> 1)).toDF("a") + + checkAnswer( + df1.select(to_json($"a")), + Row("""{"a":{"_1":1}}""") :: Nil) + checkAnswer( + df2.select(to_json($"a")), + Row("""{"a":1}""") :: Nil) } test("to_json with option") { @@ -188,15 +211,33 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row("""{"_1":"26/08/2015 18:00"}""") :: Nil) } - test("to_json unsupported type") { 
+ test("to_json - key types of map don't matter") { + // interval type is invalid for converting to JSON. However, the keys of a map are treated + // as strings, so its type doesn't matter. val df = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a") - .select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c")) + .select(struct(map($"a._1".cast(CalendarIntervalType), lit("a")).as("col1")).as("c")) + checkAnswer( + df.select(to_json($"c")), + Row("""{"col1":{"interval -3 months 7 hours":"a"}}""") :: Nil) + } + + test("to_json unsupported type") { + val baseDf = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a") + val df = baseDf.select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c")) val e = intercept[AnalysisException]{ // Unsupported type throws an exception df.select(to_json($"c")).collect() } assert(e.getMessage.contains( "Unable to convert column a of type calendarinterval to JSON.")) + + // interval type is invalid for converting to JSON. We can't use it as value type of a map. + val df2 = baseDf + .select(struct(map(lit("a"), $"a._1".cast(CalendarIntervalType)).as("col1")).as("c")) + val e2 = intercept[AnalysisException] { + df2.select(to_json($"c")).collect() + } + assert(e2.getMessage.contains("Unable to convert column col1 of type calendarinterval to JSON")) } test("roundtrip in to_json and from_json - struct") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 328c5395ec91..c2d08a06569b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) + + val bdPi: BigDecimal = BigDecimal(31415925L, 7) + checkAnswer( + sql(s"SELECT round($bdPi, 7), round($bdPi, 8), round($bdPi, 9), round($bdPi, 10), " + + s"round($bdPi, 100), round($bdPi, 6), round(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null)) + ) + + checkAnswer( + sql(s"SELECT bround($bdPi, 7), bround($bdPi, 8), bround($bdPi, 9), bround($bdPi, 10), " + + s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null)) + ) } test("round/bround with data frame from a local Seq of Product") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala index 52c200796ce4..623a1b6f854c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala @@ -22,20 +22,22 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.streaming.ProcessingTime +import org.apache.spark.sql.streaming.{ProcessingTime, Trigger} class ProcessingTimeSuite extends SparkFunSuite { test("create") { - assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000) - assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000) - assert(ProcessingTime("1 minute").intervalMs === 60 * 1000) - assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000) - - intercept[IllegalArgumentException] { 
ProcessingTime(null: String) } - intercept[IllegalArgumentException] { ProcessingTime("") } - intercept[IllegalArgumentException] { ProcessingTime("invalid") } - intercept[IllegalArgumentException] { ProcessingTime("1 month") } - intercept[IllegalArgumentException] { ProcessingTime("1 year") } + def getIntervalMs(trigger: Trigger): Long = trigger.asInstanceOf[ProcessingTime].intervalMs + + assert(getIntervalMs(Trigger.ProcessingTime(10.seconds)) === 10 * 1000) + assert(getIntervalMs(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) === 10 * 1000) + assert(getIntervalMs(Trigger.ProcessingTime("1 minute")) === 60 * 1000) + assert(getIntervalMs(Trigger.ProcessingTime("interval 1 minute")) === 60 * 1000) + + intercept[IllegalArgumentException] { Trigger.ProcessingTime(null: String) } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("") } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("invalid") } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("1 month") } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("1 year") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 7516be315dd2..57b5f5e4ab99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 2b35db411e2a..a1799829932b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -1,26 +1,26 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. 
See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3ecbf96b4196..93a7777b70b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -106,7 +106,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-14415: All functions should have own descriptions") { for (f <- spark.sessionState.functionRegistry.listFunction()) { - if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) { + if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f.unquotedString)) { checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.") } } @@ -523,14 +523,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } - test("negative in LIMIT or TABLESAMPLE") { - val expected = "The limit expression must be equal to or greater than 0, but got -1" - var e = intercept[AnalysisException] { - sql("SELECT * FROM testData TABLESAMPLE (-1 rows)") - }.getMessage - assert(e.contains(expected)) - } - test("CTE feature") { checkAnswer( sql("with q1 as (select * from testData limit 10) select * from q1"), @@ -1227,7 +1219,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-3483 Special chars in column names") { val data = Seq("""{"key?number1": "value1", "key.number2": "value2"}""").toDS() spark.read.json(data).createOrReplaceTempView("records") - sql("SELECT `key?number1`, `key.number2` FROM records") + 
withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + sql("SELECT `key?number1`, `key.number2` FROM records") + } } test("SPARK-3814 Support Bitwise & operator") { @@ -1347,7 +1341,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .json(Seq("""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""").toDS()) .createOrReplaceTempView("t") - checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } } test("SPARK-6583 order by aggregated function") { @@ -1550,10 +1546,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(d))) checkAnswer( df.selectExpr("b * a + b"), - Seq(Row(BigDecimal(2.12321)))) + Seq(Row(BigDecimal("2.12321")))) checkAnswer( df.selectExpr("b * a - b"), - Seq(Row(BigDecimal(0.12321)))) + Seq(Row(BigDecimal("0.12321")))) checkAnswer( df.selectExpr("b * a * b"), Seq(Row(d))) @@ -1843,25 +1839,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } // Create paths with unusual characters - val specialCharacterPath = sql( - """ + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val specialCharacterPath = sql( + """ | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp """.stripMargin) - withTempView("specialCharacterTable") { - specialCharacterPath.createOrReplaceTempView("specialCharacterTable") - checkAnswer( - specialCharacterPath.select($"`r&&b.c`.*"), - nestedStructData.select($"record.*")) - checkAnswer( - sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), + withTempView("specialCharacterTable") { + specialCharacterPath.createOrReplaceTempView("specialCharacterTable") + checkAnswer( + specialCharacterPath.select($"`r&&b.c`.*"), + nestedStructData.select($"record.*")) + checkAnswer( + sql( + "SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), nestedStructData.select($"record.r1")) - checkAnswer( - sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), - nestedStructData.select($"record.r2")) - checkAnswer( - sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), - nestedStructData.select($"record.r1.*")) + checkAnswer( + sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), + nestedStructData.select($"record.r2")) + checkAnswer( + sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), + nestedStructData.select($"record.r1.*")) + } } // Try star expanding a scalar. This should fail. 
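// The withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") wrappers added in
// the hunks above exist because, when that flag is on, a backquoted identifier in a SELECT is
// resolved as a column regex rather than a literal name, so tests exercising literal backquoted
// names must pin the flag to "false". A minimal sketch of the opt-in behaviour, assuming a
// local SparkSession; the object and column names are illustrative only:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.internal.SQLConf

object QuotedRegexColumnSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("regex-columns").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2, "x")).toDF("key_a", "key_b", "value")

    spark.conf.set(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key, "true")
    // With the flag on, `key_.*` expands to every column whose name matches the regex.
    df.select(expr("`key_.*`")).show()   // prints key_a and key_b

    spark.conf.set(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key, "false")
    // With the flag off (the default), the same string is looked up as a literal column
    // name and fails to resolve -- hence the explicit "false" in the tests above.
    spark.stop()
  }
}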
@@ -2619,4 +2618,63 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { new URL(jarFromInvalidFs) } } + + test("RuntimeReplaceable functions should not take extra parameters") { + val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)")) + assert(e.message.contains("Invalid number of arguments")) + } + + test("SPARK-21228: InSet incorrect handling of structs") { + withTempView("A") { + // reduce this from the default of 10 so the repro query text is not too long + withSQLConf((SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3")) { + // a relation that has 1 column of struct type with values (1,1), ..., (9, 9) + spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a") + .createOrReplaceTempView("A") + val df = sql( + """ + |SELECT * from + | (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows + | -- the IN will become InSet with a Set of GenericInternalRows + | -- a GenericInternalRow is never equal to an UnsafeRow so the query would + | -- returns 0 results, which is incorrect + | WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L), + | NAMED_STRUCT('a', 3L, 'b', 3L)) + """.stripMargin) + checkAnswer(df, Row(Row(1, 1))) + } + } + } + + test("SPARK-21335: support un-aliased subquery") { + withTempView("v") { + Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("v") + checkAnswer(sql("SELECT i from (SELECT i FROM v)"), Row(1)) + + val e = intercept[AnalysisException](sql("SELECT v.i from (SELECT i FROM v)")) + assert(e.message == + "cannot resolve '`v.i`' given input columns: [__auto_generated_subquery_name.i]") + + checkAnswer(sql("SELECT __auto_generated_subquery_name.i from (SELECT i FROM v)"), Row(1)) + } + } + + test("SPARK-21743: top-most limit should not cause memory leak") { + // In unit test, Spark will fail the query if memory leak detected. + spark.range(100).groupBy("id").count().limit(1).collect() + } + + test("SPARK-21652: rule confliction of InferFiltersFromConstraints and ConstantPropagation") { + withTempView("t1", "t2") { + Seq((1, 1)).toDF("col1", "col2").createOrReplaceTempView("t1") + Seq(1, 2).toDF("col").createOrReplaceTempView("t2") + val df = sql( + """ + |SELECT * + |FROM t1, t2 + |WHERE t1.col1 = 1 AND 1 = t1.col2 AND t1.col1 = t2.col AND t1.col2 = t2.col + """.stripMargin) + checkAnswer(df, Row(1, 1, 1)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index d9130fdcfaea..e3901af4b998 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} -import org.apache.spark.sql.execution.command.DescribeTableCommand +import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -214,11 +214,11 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { /** Executes a query and returns the result as (schema of the output, normalized output). 
*/ private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = { // Returns true if the plan is supposed to be sorted. - def needSort(plan: LogicalPlan): Boolean = plan match { + def isSorted(plan: LogicalPlan): Boolean = plan match { case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false - case _: DescribeTableCommand => true + case _: DescribeTableCommand | _: DescribeColumnCommand => true case PhysicalOperation(_, _, Sort(_, true, _)) => true - case _ => plan.children.iterator.exists(needSort) + case _ => plan.children.iterator.exists(isSorted) } try { @@ -228,11 +228,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // Get answer, but also get rid of the #1234 expression ids that show up in explain plans val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x") .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/") - .replaceAll("Created.*", s"Created $notIncludedMsg") + .replaceAll("Created By.*", s"Created By $notIncludedMsg") + .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")) // If the output is not pre-sorted, sort it. - if (needSort(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) + if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) } catch { case a: AnalysisException => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 5638c8eeda84..c01666770720 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterEach import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.QueryExecution @@ -71,10 +72,10 @@ class SessionStateSuite extends SparkFunSuite } test("fork new session and inherit function registry and udf") { - val testFuncName1 = "strlenScala" - val testFuncName2 = "addone" + val testFuncName1 = FunctionIdentifier("strlenScala") + val testFuncName2 = FunctionIdentifier("addone") try { - activeSession.udf.register(testFuncName1, (_: String).length + (_: Int)) + activeSession.udf.register(testFuncName1.funcName, (_: String).length + (_: Int)) val forkedSession = activeSession.cloneSession() // inheritance @@ -86,7 +87,7 @@ class SessionStateSuite extends SparkFunSuite // independence forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1) assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) - activeSession.udf.register(testFuncName2, (_: Int) + 1) + activeSession.udf.register(testFuncName2.funcName, (_: Int) + 1) assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty) } finally { activeSession.sessionState.functionRegistry.dropFunction(testFuncName1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 386d13d07a95..c0301f2ce2d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -17,49 +17,49 @@ package org.apache.spark.sql +import org.scalatest.BeforeAndAfterEach + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.internal.SQLConf /** * Test cases for the builder pattern of [[SparkSession]]. */ -class SparkSessionBuilderSuite extends SparkFunSuite { +class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { - private var initialSession: SparkSession = _ + override def afterEach(): Unit = { + // This suite should not interfere with the other test suites. + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.getDefaultSession.foreach(_.stop()) + SparkSession.clearDefaultSession() + } - private lazy val sparkContext: SparkContext = { - initialSession = SparkSession.builder() + test("create with config options and propagate them to SparkContext and SparkSession") { + val session = SparkSession.builder() .master("local") .config("spark.ui.enabled", value = false) .config("some-config", "v2") .getOrCreate() - initialSession.sparkContext - } - - test("create with config options and propagate them to SparkContext and SparkSession") { - // Creating a new session with config - this works by just calling the lazy val - sparkContext - assert(initialSession.sparkContext.conf.get("some-config") == "v2") - assert(initialSession.conf.get("some-config") == "v2") - SparkSession.clearDefaultSession() + assert(session.sparkContext.conf.get("some-config") == "v2") + assert(session.conf.get("some-config") == "v2") } test("use global default session") { - val session = SparkSession.builder().getOrCreate() + val session = SparkSession.builder().master("local").getOrCreate() assert(SparkSession.builder().getOrCreate() == session) - SparkSession.clearDefaultSession() } test("config options are propagated to existing SparkSession") { - val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate() + val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() assert(session1.conf.get("spark-config1") == "a") val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate() assert(session1 == session2) assert(session1.conf.get("spark-config1") == "b") - SparkSession.clearDefaultSession() } test("use session from active thread session and propagate config options") { - val defaultSession = SparkSession.builder().getOrCreate() + val defaultSession = SparkSession.builder().master("local").getOrCreate() val activeSession = defaultSession.newSession() SparkSession.setActiveSession(activeSession) val session = SparkSession.builder().config("spark-config2", "a").getOrCreate() @@ -67,19 +67,19 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(activeSession != defaultSession) assert(session == activeSession) assert(session.conf.get("spark-config2") == "a") + assert(session.sessionState.conf == SQLConf.get) + assert(SQLConf.get.getConfString("spark-config2") == "a") SparkSession.clearActiveSession() assert(SparkSession.builder().getOrCreate() == defaultSession) - SparkSession.clearDefaultSession() } test("create a new session if the default session has been stopped") { - val defaultSession = SparkSession.builder().getOrCreate() + val defaultSession = SparkSession.builder().master("local").getOrCreate() SparkSession.setDefaultSession(defaultSession) defaultSession.stop() val newSession = 
SparkSession.builder().master("local").getOrCreate() assert(newSession != defaultSession) - newSession.stop() } test("create a new session if the active thread session has been stopped") { @@ -88,27 +88,38 @@ class SparkSessionBuilderSuite extends SparkFunSuite { activeSession.stop() val newSession = SparkSession.builder().master("local").getOrCreate() assert(newSession != activeSession) - newSession.stop() } test("create SparkContext first then SparkSession") { - sparkContext.stop() val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") val sparkContext2 = new SparkContext(conf) val session = SparkSession.builder().config("key2", "value2").getOrCreate() assert(session.conf.get("key1") == "value1") assert(session.conf.get("key2") == "value2") + assert(session.sparkContext == sparkContext2) + // We won't update conf for existing `SparkContext` + assert(!sparkContext2.conf.contains("key2")) + assert(sparkContext2.conf.get("key1") == "value1") + } + + test("create SparkContext first then pass context to SparkSession") { + val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") + val newSC = new SparkContext(conf) + val session = SparkSession.builder().sparkContext(newSC).config("key2", "value2").getOrCreate() + assert(session.conf.get("key1") == "value1") + assert(session.conf.get("key2") == "value2") + assert(session.sparkContext == newSC) assert(session.sparkContext.conf.get("key1") == "value1") - assert(session.sparkContext.conf.get("key2") == "value2") + // If the created sparkContext is passed through the Builder's API sparkContext, + // the conf of this sparkContext will not contain the conf set through the API config. + assert(!session.sparkContext.conf.contains("key2")) assert(session.sparkContext.conf.get("spark.app.name") == "test") - session.stop() } test("SPARK-15887: hive-site.xml should be loaded") { val session = SparkSession.builder().master("local").getOrCreate() assert(session.sessionState.newHadoopConf().get("hive.in.test") == "true") assert(session.sparkContext.hadoopConfiguration.get("hive.in.test") == "true") - session.stop() } test("SPARK-15991: Set global Hadoop conf") { @@ -120,7 +131,6 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(session.sessionState.newHadoopConf().get(mySpecialKey) == mySpecialValue) } finally { session.sparkContext.hadoopConfiguration.unset(mySpecialKey) - session.stop() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index ddc393c8da05..2fc92f4aff92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -17,19 +17,12 @@ package org.apache.spark.sql -import java.{lang => jl} -import java.sql.{Date, Timestamp} - import scala.collection.mutable -import scala.util.Random import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.internal.StaticSQLConf -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext import 
org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ @@ -40,17 +33,6 @@ import org.apache.spark.sql.types._ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext { import testImplicits._ - private def checkTableStats(tableName: String, expectedRowCount: Option[Int]) - : Option[CatalogStatistics] = { - val df = spark.table(tableName) - val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } - test("estimates the size of a limit 0 on outer join") { withTempView("test") { Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") @@ -60,7 +42,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val df = df1.join(df2, Seq("k"), "left") val sizes = df.queryExecution.analyzed.collect { case g: Join => - g.stats(conf).sizeInBytes + g.stats.sizeInBytes } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") @@ -69,6 +51,50 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("analyzing views is not supported") { + def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { + val err = intercept[AnalysisException] { + sql(analyzeCommand) + } + assert(err.message.contains("ANALYZE TABLE is not supported")) + } + + val tableName = "tbl" + withTable(tableName) { + spark.range(10).write.saveAsTable(tableName) + val viewName = "view" + withView(viewName) { + sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") + } + } + } + + test("statistics collection of a table with zero column") { + val table_no_cols = "table_no_cols" + withTable(table_no_cols) { + val rddNoCols = sparkContext.parallelize(1 to 10).map(_ => Row.empty) + val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) + dfNoCols.write.format("json").saveAsTable(table_no_cols) + sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") + checkTableStats(table_no_cols, hasSizeInBytes = true, expectedRowCounts = Some(10)) + } + } + + test("analyze empty table") { + val table = "emptyTable" + withTable(table) { + sql(s"CREATE TABLE $table (key STRING, value STRING) USING PARQUET") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS noscan") + val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetchedStats1.get.sizeInBytes == 0) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetchedStats2.get.sizeInBytes == 0) + } + } + test("analyze column command - unsupported types and invalid columns") { val tableName = "column_stats_test1" withTable(tableName) { @@ -96,20 +122,20 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared // noscan won't count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - checkTableStats(tableName, expectedRowCount = None) + checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = None) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") - checkTableStats(tableName, expectedRowCount = Some(2)) + checkTableStats(tableName, 
hasSizeInBytes = true, expectedRowCounts = Some(2)) } } test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) - assert(df.queryExecution.analyzed.stats(conf).sizeInBytes > + assert(df.queryExecution.analyzed.stats.sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) - assert(df.selectExpr("a").queryExecution.analyzed.stats(conf).sizeInBytes > + assert(df.selectExpr("a").queryExecution.analyzed.stats.sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) } @@ -150,157 +176,176 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("number format in statistics") { val numbers = Seq( - BigInt(0) -> ("0.0 B", "0"), - BigInt(100) -> ("100.0 B", "100"), - BigInt(2047) -> ("2047.0 B", "2.05E+3"), - BigInt(2048) -> ("2.0 KB", "2.05E+3"), - BigInt(3333333) -> ("3.2 MB", "3.33E+6"), - BigInt(4444444444L) -> ("4.1 GB", "4.44E+9"), - BigInt(5555555555555L) -> ("5.1 TB", "5.56E+12"), - BigInt(6666666666666666L) -> ("5.9 PB", "6.67E+15"), - BigInt(1L << 10 ) * (1L << 60) -> ("1024.0 EB", "1.18E+21"), - BigInt(1L << 11) * (1L << 60) -> ("2.36E+21 B", "2.36E+21") + BigInt(0) -> (("0.0 B", "0")), + BigInt(100) -> (("100.0 B", "100")), + BigInt(2047) -> (("2047.0 B", "2.05E+3")), + BigInt(2048) -> (("2.0 KB", "2.05E+3")), + BigInt(3333333) -> (("3.2 MB", "3.33E+6")), + BigInt(4444444444L) -> (("4.1 GB", "4.44E+9")), + BigInt(5555555555555L) -> (("5.1 TB", "5.56E+12")), + BigInt(6666666666666666L) -> (("5.9 PB", "6.67E+15")), + BigInt(1L << 10 ) * (1L << 60) -> (("1024.0 EB", "1.18E+21")), + BigInt(1L << 11) * (1L << 60) -> (("2.36E+21 B", "2.36E+21")) ) numbers.foreach { case (input, (expectedSize, expectedRows)) => val stats = Statistics(sizeInBytes = input, rowCount = Some(input)) val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows," + - s" isBroadcastable=${stats.isBroadcastable}" + s" hints=none" assert(stats.simpleString == expectedString) } } -} - -/** - * The base for test cases that we want to include in both the hive module (for verifying behavior - * when using the Hive external catalog) as well as in the sql/core module. - */ -abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils { - import testImplicits._ - - private val dec1 = new java.math.BigDecimal("1.000000000000000000") - private val dec2 = new java.math.BigDecimal("8.000000000000000000") - private val d1 = Date.valueOf("2016-05-08") - private val d2 = Date.valueOf("2016-05-09") - private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") - private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") - - /** - * Define a very simple 3 row table used for testing column serialization. - * Note: last column is seq[int] which doesn't support stats collection. - */ - protected val data = Seq[ - (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long, - jl.Double, jl.Float, java.math.BigDecimal, - String, Array[Byte], Date, Timestamp, - Seq[Int])]( - (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null), - (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null), - (null, null, null, null, null, null, null, null, null, null, null, null, null) - ) - - /** A mapping from column to the stats collected. 
*/ - protected val stats = mutable.LinkedHashMap( - "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), - "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), - "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), - "cstring" -> ColumnStat(2, None, None, 1, 3, 3), - "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), - Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), - Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) - ) - - private val randomName = new Random(31) - - /** - * Compute column stats for the given DataFrame and compare it with colStats. - */ - def checkColStats( - df: DataFrame, - colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { - val tableName = "column_stats_test_" + randomName.nextInt(1000) - withTable(tableName) { - df.write.saveAsTable(tableName) - - // Collect statistics - sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + - colStats.keys.mkString(", ")) + test("change stats after truncate command") { + val table = "change_stats_truncate_table" + withTable(table) { + spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(100)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + // truncate table command + sql(s"TRUNCATE TABLE $table") + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched2.get.sizeInBytes == 0) + assert(fetched2.get.colStats.isEmpty) + } + } - // Validate statistics - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) - assert(table.stats.isDefined) - assert(table.stats.get.colStats.size == colStats.size) + test("change stats after set location command") { + val table = "change_stats_set_location_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") + val fetched1 = checkTableStats( + table, hasSizeInBytes = true, expectedRowCounts = Some(100)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + // set location command + val initLocation = spark.sessionState.catalog.getTableMetadata(TableIdentifier(table)) + .storage.locationUri.get.toString + withTempDir { newLocation => + sql(s"ALTER TABLE $table SET LOCATION '${newLocation.toURI.toString}'") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes == 0) + assert(fetched2.get.colStats.isEmpty) + + // set back to the initial location + sql(s"ALTER TABLE $table SET LOCATION '$initLocation'") + val fetched3 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = 
None) + assert(fetched3.get.sizeInBytes == fetched1.get.sizeInBytes) + } else { + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } + } + } + } - colStats.foreach { case (k, v) => - withClue(s"column $k") { - assert(table.stats.get.colStats(k) == v) + test("change stats after insert command for datasource table") { + val table = "change_stats_insert_datasource_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string) USING PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // table lookup will make the table cached + spark.table(table) + assert(isTableInCatalogCache(table)) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.isEmpty) + } else { + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + + // check that tableRelationCache inside the catalog was invalidated after insert + assert(!isTableInCatalogCache(table)) } } } } - // This test will be run twice: with and without Hive support - test("SPARK-18856: non-empty partitioned table should not report zero size") { - withTable("ds_tbl", "hive_tbl") { - spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") - val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats(conf) - assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") - - if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { - sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") - sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") - val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats(conf) - assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") + test("invalidation of tableRelationCache after inserts") { + val table = "invalidate_catalog_cache_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + spark.range(100).write.saveAsTable(table) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + spark.table(table) + val initialSizeInBytes = getTableFromCatalogCache(table).stats.sizeInBytes + spark.range(100).write.mode(SaveMode.Append).saveAsTable(table) + spark.table(table) + assert(getTableFromCatalogCache(table).stats.sizeInBytes == 2 * initialSizeInBytes) + } } } } - // This test will be run twice: with and without Hive support - test("conversion from CatalogStatistics to Statistics") { - withTable("ds_tbl", "hive_tbl") { - // Test data source table - checkStatsConversion(tableName = "ds_tbl", isDatasourceTable = true) - // Test hive serde table - if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { - checkStatsConversion(tableName = "hive_tbl", isDatasourceTable = false) + test("invalidation of tableRelationCache after table truncation") { + val table = "invalidate_catalog_cache_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) 
{ + withTable(table) { + spark.range(100).write.saveAsTable(table) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + spark.table(table) + sql(s"TRUNCATE TABLE $table") + spark.table(table) + assert(getTableFromCatalogCache(table).stats.sizeInBytes == 0) + } } } } - private def checkStatsConversion(tableName: String, isDatasourceTable: Boolean): Unit = { - // Create an empty table and run analyze command on it. - val createTableSql = if (isDatasourceTable) { - s"CREATE TABLE $tableName (c1 INT, c2 STRING) USING PARQUET" - } else { - s"CREATE TABLE $tableName (c1 INT, c2 STRING)" + test("invalidation of tableRelationCache after alter table add partition") { + val table = "invalidate_catalog_cache_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTempDir { dir => + withTable(table) { + val path = dir.getCanonicalPath + sql(s""" + |CREATE TABLE $table (col1 int, col2 int) + |USING PARQUET + |PARTITIONED BY (col2) + |LOCATION '${dir.toURI}'""".stripMargin) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + spark.table(table) + assert(getTableFromCatalogCache(table).stats.sizeInBytes == 0) + spark.catalog.recoverPartitions(table) + val df = Seq((1, 2), (1, 2)).toDF("col2", "col1") + df.write.parquet(s"$path/col2=1") + sql(s"ALTER TABLE $table ADD PARTITION (col2=1) LOCATION '${dir.toURI}'") + spark.table(table) + val cachedTable = getTableFromCatalogCache(table) + val cachedTableSizeInBytes = cachedTable.stats.sizeInBytes + val defaultSizeInBytes = conf.defaultSizeInBytes + if (autoUpdate) { + assert(cachedTableSizeInBytes != defaultSizeInBytes && cachedTableSizeInBytes > 0) + } else { + assert(cachedTableSizeInBytes == defaultSizeInBytes) + } + } + } + } } - sql(createTableSql) - // Analyze only one column. - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1") - val (relation, catalogTable) = spark.table(tableName).queryExecution.analyzed.collect { - case catalogRel: CatalogRelation => (catalogRel, catalogRel.tableMeta) - case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) - }.head - val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) - // Check catalog statistics - assert(catalogTable.stats.isDefined) - assert(catalogTable.stats.get.sizeInBytes == 0) - assert(catalogTable.stats.get.rowCount == Some(0)) - assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) - - // Check relation statistics - assert(relation.stats(conf).sizeInBytes == 0) - assert(relation.stats(conf).rowCount == Some(0)) - assert(relation.stats(conf).attributeStats.size == 1) - val (attribute, colStat) = relation.stats(conf).attributeStats.head - assert(attribute.name == "c1") - assert(colStat == emptyColStat) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala new file mode 100644 index 000000000000..a2f63edd786b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.{lang => jl} +import java.sql.{Date, Timestamp} + +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.Decimal + + +/** + * The base for statistics test cases that we want to include in both the hive module (for + * verifying behavior when using the Hive external catalog) as well as in the sql/core module. + */ +abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils { + import testImplicits._ + + private val dec1 = new java.math.BigDecimal("1.000000000000000000") + private val dec2 = new java.math.BigDecimal("8.000000000000000000") + private val d1 = Date.valueOf("2016-05-08") + private val d2 = Date.valueOf("2016-05-09") + private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") + private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + + /** + * Define a very simple 3 row table used for testing column serialization. + * Note: last column is seq[int] which doesn't support stats collection. + */ + protected val data = Seq[ + (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long, + jl.Double, jl.Float, java.math.BigDecimal, + String, Array[Byte], Date, Timestamp, + Seq[Int])]( + (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null), + (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null), + (null, null, null, null, null, null, null, null, null, null, null, null, null) + ) + + /** A mapping from column to the stats collected. 
*/ + protected val stats = mutable.LinkedHashMap( + "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), + "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), + "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), + "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), + "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), + "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), + "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), + "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), + "cstring" -> ColumnStat(2, None, None, 1, 3, 3), + "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), + "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), + Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), + Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) + ) + + private val randomName = new Random(31) + + def getCatalogTable(tableName: String): CatalogTable = { + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + } + + def getTableFromCatalogCache(tableName: String): LogicalPlan = { + val catalog = spark.sessionState.catalog + val qualifiedTableName = QualifiedTableName(catalog.getCurrentDatabase, tableName) + catalog.getCachedTable(qualifiedTableName) + } + + def isTableInCatalogCache(tableName: String): Boolean = { + getTableFromCatalogCache(tableName) != null + } + + def getCatalogStatistics(tableName: String): CatalogStatistics = { + getCatalogTable(tableName).stats.get + } + + def checkTableStats( + tableName: String, + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { + val stats = getCatalogTable(tableName).stats + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { + assert(stats.isDefined) + assert(stats.get.sizeInBytes >= 0) + assert(stats.get.rowCount === expectedRowCounts) + } else { + assert(stats.isEmpty) + } + + stats + } + + /** + * Compute column stats for the given DataFrame and compare it with colStats. 
+ */ + def checkColStats( + df: DataFrame, + colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { + val tableName = "column_stats_test_" + randomName.nextInt(1000) + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Collect statistics + sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + + colStats.keys.mkString(", ")) + + // Validate statistics + val table = getCatalogTable(tableName) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == colStats.size) + + colStats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) + } + } + } + } + + // This test will be run twice: with and without Hive support + test("SPARK-18856: non-empty partitioned table should not report zero size") { + withTable("ds_tbl", "hive_tbl") { + spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") + val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats + assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") + + if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { + sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") + sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") + val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats + assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") + } + } + } + + // This test will be run twice: with and without Hive support + test("conversion from CatalogStatistics to Statistics") { + withTable("ds_tbl", "hive_tbl") { + // Test data source table + checkStatsConversion(tableName = "ds_tbl", isDatasourceTable = true) + // Test hive serde table + if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { + checkStatsConversion(tableName = "hive_tbl", isDatasourceTable = false) + } + } + } + + private def checkStatsConversion(tableName: String, isDatasourceTable: Boolean): Unit = { + // Create an empty table and run analyze command on it. + val createTableSql = if (isDatasourceTable) { + s"CREATE TABLE $tableName (c1 INT, c2 STRING) USING PARQUET" + } else { + s"CREATE TABLE $tableName (c1 INT, c2 STRING)" + } + sql(createTableSql) + // Analyze only one column. 
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1") + val (relation, catalogTable) = spark.table(tableName).queryExecution.analyzed.collect { + case catalogRel: HiveTableRelation => (catalogRel, catalogRel.tableMeta) + case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) + }.head + val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) + // Check catalog statistics + assert(catalogTable.stats.isDefined) + assert(catalogTable.stats.get.sizeInBytes == 0) + assert(catalogTable.stats.get.rowCount == Some(0)) + assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) + + // Check relation statistics + assert(relation.stats.sizeInBytes == 0) + assert(relation.stats.rowCount == Some(0)) + assert(relation.stats.attributeStats.size == 1) + val (attribute, colStat) = relation.stats.attributeStats.head + assert(attribute.name == "c1") + assert(colStat == emptyColStat) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index bcc235104995..3d76b9ac33e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -161,12 +161,24 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("string trim functions") { - val df = Seq((" example ", "")).toDF("a", "b") + val df = Seq((" example ", "", "example")).toDF("a", "b", "c") checkAnswer( df.select(ltrim($"a"), rtrim($"a"), trim($"a")), Row("example ", " example", "example")) + checkAnswer( + df.select(ltrim($"c", "e"), rtrim($"c", "e"), trim($"c", "e")), + Row("xample", "exampl", "xampl")) + + checkAnswer( + df.select(ltrim($"c", "xe"), rtrim($"c", "emlp"), trim($"c", "elxp")), + Row("ample", "exa", "am")) + + checkAnswer( + df.select(trim($"c", "xyz")), + Row("example")) + checkAnswer( df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), Row("example ", " example", "example")) @@ -387,7 +399,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("6.4817")) checkAnswer( - df.select(format_number(lit(BigDecimal(7.128381)), 4)), // not convert anything + df.select(format_number(lit(BigDecimal("7.128381")), 4)), // not convert anything Row("7.1284")) intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 131abf7c1e5d..8673dc14f759 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { @@ -27,23 +28,23 @@ class SubquerySuite extends QueryTest with SharedSQLContext { val row = identity[(java.lang.Integer, java.lang.Double)](_) lazy val l = Seq( - row(1, 2.0), - row(1, 2.0), - row(2, 1.0), - row(2, 1.0), - row(3, 3.0), - row(null, null), - row(null, 5.0), - row(6, null)).toDF("a", "b") + row((1, 2.0)), + row((1, 2.0)), + row((2, 1.0)), + row((2, 1.0)), + row((3, 3.0)), + row((null, null)), + row((null, 5.0)), + row((6, null))).toDF("a", "b") lazy val r = Seq( - row(2, 3.0), - row(2, 3.0), - row(3, 2.0), - row(4, 1.0), - row(null, null), - row(null, 5.0), - row(6, null)).toDF("c", "d") + row((2, 3.0)), + row((2, 3.0)), + 
row((3, 2.0)), + row((4, 1.0)), + row((null, null)), + row((null, 5.0)), + row((6, null))).toDF("c", "d") lazy val t = r.filter($"c".isNotNull && $"d".isNotNull) @@ -72,7 +73,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { } } - test("rdd deserialization does not crash [SPARK-15791]") { + test("SPARK-15791: rdd deserialization does not crash") { sql("select (select 1 as b) as b").rdd.count() } @@ -517,7 +518,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { val msg1 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") } - assert(msg1.getMessage.contains("Correlated scalar subqueries must be Aggregated")) + assert(msg1.getMessage.contains("Correlated scalar subqueries must be aggregated")) val msg2 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") @@ -655,7 +656,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { """ | select c1 from onerow t1 | where exists (select 1 - | from (select 1 from onerow t2 LIMIT 1) + | from (select c1 from onerow t2 LIMIT 1) t2 | where t1.c1=t2.c1)""".stripMargin), Row(1) :: Nil) } @@ -867,4 +868,86 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)"), Row(3, 3.0, 2, 3.0) :: Row(3, 3.0, 2, 3.0) :: Nil) } + + test("SPARK-20688: correctly check analysis for scalar sub-queries") { + withTempView("t") { + Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("t") + val e = intercept[AnalysisException](sql("SELECT (SELECT count(*) FROM t WHERE a = 1)")) + assert(e.message.contains("cannot resolve '`a`' given input columns: [t.i, t.j]")) + } + } + + test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 1") { + withTable("t1") { + withTempPath { path => + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath) + sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}'") + + val sqlText = + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin + val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan + val join = optimizedPlan.collectFirst { case j: Join => j }.get + assert(join.duplicateResolved) + assert(optimizedPlan.resolved) + } + } + } + + test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 2") { + withTable("t1", "t2", "t3") { + withTempPath { path => + val data = Seq((1, 1, 1), (2, 0, 2)) + + data.toDF("t1a", "t1b", "t1c").write.parquet(path.getCanonicalPath + "/t1") + data.toDF("t2a", "t2b", "t2c").write.parquet(path.getCanonicalPath + "/t2") + data.toDF("t3a", "t3b", "t3c").write.parquet(path.getCanonicalPath + "/t3") + + sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}/t1'") + sql(s"CREATE TABLE t2 USING parquet LOCATION '${path.toURI}/t2'") + sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}/t3'") + + val sqlText = + s""" + |SELECT * + |FROM (SELECT * + | FROM t2 + | WHERE t2c IN (SELECT t1c + | FROM t1 + | WHERE t1a = t2a) + | UNION + | SELECT * + | FROM t3 + | WHERE t3a IN (SELECT t2a + | FROM t2 + | UNION ALL + | SELECT t1a + | FROM t1 + | WHERE t1b > 0)) t4 + |WHERE t4.t2b IN (SELECT Min(t3b) + | FROM t3 + | WHERE t4.t2a = t3a) + """.stripMargin + val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan + val joinNodes = optimizedPlan.collect { case j: Join => j } + joinNodes.foreach(j => assert(j.duplicateResolved)) + assert(optimizedPlan.resolved) + } + } + 
} + + test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 3") { + val sqlText = + """ + |SELECT * FROM l, r WHERE l.a = r.c + 1 AND + |(EXISTS (SELECT * FROM r) OR l.a = r.c) + """.stripMargin + val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan + val join = optimizedPlan.collectFirst { case j: Join => j }.get + assert(join.duplicateResolved) + assert(optimizedPlan.resolved) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala new file mode 100644 index 000000000000..e47d4b0ee25d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +/** + * This test suite ensures all the TPC-DS queries can be successfully analyzed and optimized + * without hitting the max iteration threshold. + */ +class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { + + // When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting + // the max iteration of analyzer/optimizer batches. 
+ assert(Utils.isTesting, "spark.testing is not set to true") + + /** + * Drop all the tables + */ + protected override def afterAll(): Unit = { + try { + spark.sessionState.catalog.reset() + } finally { + super.afterAll() + } + } + + override def beforeAll() { + super.beforeAll() + sql( + """ + |CREATE TABLE `catalog_page` ( + |`cp_catalog_page_sk` INT, `cp_catalog_page_id` STRING, `cp_start_date_sk` INT, + |`cp_end_date_sk` INT, `cp_department` STRING, `cp_catalog_number` INT, + |`cp_catalog_page_number` INT, `cp_description` STRING, `cp_type` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `catalog_returns` ( + |`cr_returned_date_sk` INT, `cr_returned_time_sk` INT, `cr_item_sk` INT, + |`cr_refunded_customer_sk` INT, `cr_refunded_cdemo_sk` INT, `cr_refunded_hdemo_sk` INT, + |`cr_refunded_addr_sk` INT, `cr_returning_customer_sk` INT, `cr_returning_cdemo_sk` INT, + |`cr_returning_hdemo_sk` INT, `cr_returning_addr_sk` INT, `cr_call_center_sk` INT, + |`cr_catalog_page_sk` INT, `cr_ship_mode_sk` INT, `cr_warehouse_sk` INT, `cr_reason_sk` INT, + |`cr_order_number` INT, `cr_return_quantity` INT, `cr_return_amount` DECIMAL(7,2), + |`cr_return_tax` DECIMAL(7,2), `cr_return_amt_inc_tax` DECIMAL(7,2), `cr_fee` DECIMAL(7,2), + |`cr_return_ship_cost` DECIMAL(7,2), `cr_refunded_cash` DECIMAL(7,2), + |`cr_reversed_charge` DECIMAL(7,2), `cr_store_credit` DECIMAL(7,2), + |`cr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer` ( + |`c_customer_sk` INT, `c_customer_id` STRING, `c_current_cdemo_sk` INT, + |`c_current_hdemo_sk` INT, `c_current_addr_sk` INT, `c_first_shipto_date_sk` INT, + |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING, + |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT, + |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING, + |`c_email_address` STRING, `c_last_review_date` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer_address` ( + |`ca_address_sk` INT, `ca_address_id` STRING, `ca_street_number` STRING, + |`ca_street_name` STRING, `ca_street_type` STRING, `ca_suite_number` STRING, + |`ca_city` STRING, `ca_county` STRING, `ca_state` STRING, `ca_zip` STRING, + |`ca_country` STRING, `ca_gmt_offset` DECIMAL(5,2), `ca_location_type` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer_demographics` ( + |`cd_demo_sk` INT, `cd_gender` STRING, `cd_marital_status` STRING, + |`cd_education_status` STRING, `cd_purchase_estimate` INT, `cd_credit_rating` STRING, + |`cd_dep_count` INT, `cd_dep_employed_count` INT, `cd_dep_college_count` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `date_dim` ( + |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING, + |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT, + |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT, + |`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING, + |`d_weekend` STRING, `d_following_holiday` STRING, `d_first_dom` INT, `d_last_dom` INT, + |`d_same_day_ly` INT, `d_same_day_lq` INT, `d_current_day` STRING, `d_current_week` STRING, + |`d_current_month` STRING, `d_current_quarter` STRING, `d_current_year` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `household_demographics` ( + |`hd_demo_sk` INT, `hd_income_band_sk` INT, `hd_buy_potential` STRING, `hd_dep_count` INT, 
+ |`hd_vehicle_count` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `inventory` (`inv_date_sk` INT, `inv_item_sk` INT, `inv_warehouse_sk` INT, + |`inv_quantity_on_hand` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING, + |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), + |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT, + |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT, + |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING, + |`i_units` STRING, `i_container` STRING, `i_manager_id` INT, `i_product_name` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `promotion` ( + |`p_promo_sk` INT, `p_promo_id` STRING, `p_start_date_sk` INT, `p_end_date_sk` INT, + |`p_item_sk` INT, `p_cost` DECIMAL(15,2), `p_response_target` INT, `p_promo_name` STRING, + |`p_channel_dmail` STRING, `p_channel_email` STRING, `p_channel_catalog` STRING, + |`p_channel_tv` STRING, `p_channel_radio` STRING, `p_channel_press` STRING, + |`p_channel_event` STRING, `p_channel_demo` STRING, `p_channel_details` STRING, + |`p_purpose` STRING, `p_discount_active` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store` ( + |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING, + |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING, + |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING, + |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING, + |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING, + |`s_company_id` INT, `s_company_name` STRING, `s_street_number` STRING, + |`s_street_name` STRING, `s_street_type` STRING, `s_suite_number` STRING, `s_city` STRING, + |`s_county` STRING, `s_state` STRING, `s_zip` STRING, `s_country` STRING, + |`s_gmt_offset` DECIMAL(5,2), `s_tax_precentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store_returns` ( + |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT, + |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT, + |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT, + |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), + |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2), + |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2), + |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2), + |`sr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `catalog_sales` ( + |`cs_sold_date_sk` INT, `cs_sold_time_sk` INT, `cs_ship_date_sk` INT, + |`cs_bill_customer_sk` INT, `cs_bill_cdemo_sk` INT, `cs_bill_hdemo_sk` INT, + |`cs_bill_addr_sk` INT, `cs_ship_customer_sk` INT, `cs_ship_cdemo_sk` INT, + |`cs_ship_hdemo_sk` INT, `cs_ship_addr_sk` INT, `cs_call_center_sk` INT, + |`cs_catalog_page_sk` INT, `cs_ship_mode_sk` INT, `cs_warehouse_sk` INT, + |`cs_item_sk` INT, `cs_promo_sk` INT, `cs_order_number` INT, `cs_quantity` INT, + |`cs_wholesale_cost` DECIMAL(7,2), `cs_list_price` DECIMAL(7,2), + |`cs_sales_price` DECIMAL(7,2), `cs_ext_discount_amt` DECIMAL(7,2), + |`cs_ext_sales_price` DECIMAL(7,2), `cs_ext_wholesale_cost` DECIMAL(7,2), + |`cs_ext_list_price` DECIMAL(7,2), 
`cs_ext_tax` DECIMAL(7,2), `cs_coupon_amt` DECIMAL(7,2), + |`cs_ext_ship_cost` DECIMAL(7,2), `cs_net_paid` DECIMAL(7,2), + |`cs_net_paid_inc_tax` DECIMAL(7,2), `cs_net_paid_inc_ship` DECIMAL(7,2), + |`cs_net_paid_inc_ship_tax` DECIMAL(7,2), `cs_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_sales` ( + |`ws_sold_date_sk` INT, `ws_sold_time_sk` INT, `ws_ship_date_sk` INT, `ws_item_sk` INT, + |`ws_bill_customer_sk` INT, `ws_bill_cdemo_sk` INT, `ws_bill_hdemo_sk` INT, + |`ws_bill_addr_sk` INT, `ws_ship_customer_sk` INT, `ws_ship_cdemo_sk` INT, + |`ws_ship_hdemo_sk` INT, `ws_ship_addr_sk` INT, `ws_web_page_sk` INT, `ws_web_site_sk` INT, + |`ws_ship_mode_sk` INT, `ws_warehouse_sk` INT, `ws_promo_sk` INT, `ws_order_number` INT, + |`ws_quantity` INT, `ws_wholesale_cost` DECIMAL(7,2), `ws_list_price` DECIMAL(7,2), + |`ws_sales_price` DECIMAL(7,2), `ws_ext_discount_amt` DECIMAL(7,2), + |`ws_ext_sales_price` DECIMAL(7,2), `ws_ext_wholesale_cost` DECIMAL(7,2), + |`ws_ext_list_price` DECIMAL(7,2), `ws_ext_tax` DECIMAL(7,2), + |`ws_coupon_amt` DECIMAL(7,2), `ws_ext_ship_cost` DECIMAL(7,2), `ws_net_paid` DECIMAL(7,2), + |`ws_net_paid_inc_tax` DECIMAL(7,2), `ws_net_paid_inc_ship` DECIMAL(7,2), + |`ws_net_paid_inc_ship_tax` DECIMAL(7,2), `ws_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store_sales` ( + |`ss_sold_date_sk` INT, `ss_sold_time_sk` INT, `ss_item_sk` INT, `ss_customer_sk` INT, + |`ss_cdemo_sk` INT, `ss_hdemo_sk` INT, `ss_addr_sk` INT, `ss_store_sk` INT, + |`ss_promo_sk` INT, `ss_ticket_number` INT, `ss_quantity` INT, + |`ss_wholesale_cost` DECIMAL(7,2), `ss_list_price` DECIMAL(7,2), + |`ss_sales_price` DECIMAL(7,2), `ss_ext_discount_amt` DECIMAL(7,2), + |`ss_ext_sales_price` DECIMAL(7,2), `ss_ext_wholesale_cost` DECIMAL(7,2), + |`ss_ext_list_price` DECIMAL(7,2), `ss_ext_tax` DECIMAL(7,2), + |`ss_coupon_amt` DECIMAL(7,2), `ss_net_paid` DECIMAL(7,2), + |`ss_net_paid_inc_tax` DECIMAL(7,2), `ss_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_returns` ( + |`wr_returned_date_sk` BIGINT, `wr_returned_time_sk` BIGINT, `wr_item_sk` BIGINT, + |`wr_refunded_customer_sk` BIGINT, `wr_refunded_cdemo_sk` BIGINT, + |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT, + |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT, + |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT, + |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT, + |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2), + |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2), + |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2), + |`wr_reversed_charge` DECIMAL(7,2), `wr_account_credit` DECIMAL(7,2), + |`wr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_site` ( + |`web_site_sk` INT, `web_site_id` STRING, `web_rec_start_date` DATE, + |`web_rec_end_date` DATE, `web_name` STRING, `web_open_date_sk` INT, + |`web_close_date_sk` INT, `web_class` STRING, `web_manager` STRING, `web_mkt_id` INT, + |`web_mkt_class` STRING, `web_mkt_desc` STRING, `web_market_manager` STRING, + |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING, + |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING, + |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING, + |`web_country` STRING, 
`web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `reason` ( + |`r_reason_sk` INT, `r_reason_id` STRING, `r_reason_desc` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `call_center` ( + |`cc_call_center_sk` INT, `cc_call_center_id` STRING, `cc_rec_start_date` DATE, + |`cc_rec_end_date` DATE, `cc_closed_date_sk` INT, `cc_open_date_sk` INT, `cc_name` STRING, + |`cc_class` STRING, `cc_employees` INT, `cc_sq_ft` INT, `cc_hours` STRING, + |`cc_manager` STRING, `cc_mkt_id` INT, `cc_mkt_class` STRING, `cc_mkt_desc` STRING, + |`cc_market_manager` STRING, `cc_division` INT, `cc_division_name` STRING, `cc_company` INT, + |`cc_company_name` STRING, `cc_street_number` STRING, `cc_street_name` STRING, + |`cc_street_type` STRING, `cc_suite_number` STRING, `cc_city` STRING, `cc_county` STRING, + |`cc_state` STRING, `cc_zip` STRING, `cc_country` STRING, `cc_gmt_offset` DECIMAL(5,2), + |`cc_tax_percentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `warehouse` ( + |`w_warehouse_sk` INT, `w_warehouse_id` STRING, `w_warehouse_name` STRING, + |`w_warehouse_sq_ft` INT, `w_street_number` STRING, `w_street_name` STRING, + |`w_street_type` STRING, `w_suite_number` STRING, `w_city` STRING, `w_county` STRING, + |`w_state` STRING, `w_zip` STRING, `w_country` STRING, `w_gmt_offset` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `ship_mode` ( + |`sm_ship_mode_sk` INT, `sm_ship_mode_id` STRING, `sm_type` STRING, `sm_code` STRING, + |`sm_carrier` STRING, `sm_contract` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `income_band` ( + |`ib_income_band_sk` INT, `ib_lower_bound` INT, `ib_upper_bound` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `time_dim` ( + |`t_time_sk` INT, `t_time_id` STRING, `t_time` INT, `t_hour` INT, `t_minute` INT, + |`t_second` INT, `t_am_pm` STRING, `t_shift` STRING, `t_sub_shift` STRING, + |`t_meal_time` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_page` (`wp_web_page_sk` INT, `wp_web_page_id` STRING, + |`wp_rec_start_date` DATE, `wp_rec_end_date` DATE, `wp_creation_date_sk` INT, + |`wp_access_date_sk` INT, `wp_autogen_flag` STRING, `wp_customer_sk` INT, + |`wp_url` STRING, `wp_type` STRING, `wp_char_count` INT, `wp_link_count` INT, + |`wp_image_count` INT, `wp_max_ad_count` INT) + |USING parquet + """.stripMargin) + } + + val tpcdsQueries = Seq( + "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", + "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", "q30", + "q31", "q32", "q33", "q34", "q35", "q36", "q37", "q38", "q39a", "q39b", "q40", + "q41", "q42", "q43", "q44", "q45", "q46", "q47", "q48", "q49", "q50", + "q51", "q52", "q53", "q54", "q55", "q56", "q57", "q58", "q59", "q60", + "q61", "q62", "q63", "q64", "q65", "q66", "q67", "q68", "q69", "q70", + "q71", "q72", "q73", "q74", "q75", "q76", "q77", "q78", "q79", "q80", + "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", + "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + + tpcdsQueries.foreach { name => + val queryString = resourceToString(s"tpcds/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(name) { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // Just check the plans can be 
properly generated + sql(queryString).queryExecution.executedPlan + } + } + } + + // These queries are from https://github.com/cloudera/impala-tpcds-kit/tree/master/queries + val modifiedTPCDSQueries = Seq( + "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", + "q63", "q65", "q68", "q73", "q79", "q89", "q98", "ss_max") + + modifiedTPCDSQueries.foreach { name => + val queryString = resourceToString(s"tpcds-modifiedQueries/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(s"modified-$name") { + // Just check the plans can be properly generated + sql(queryString).queryExecution.executedPlan + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ae6b2bc3753f..7f1c009ca6e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ +import org.apache.spark.sql.types.DataTypes private case class FunctionResult(f1: String, f2: String) @@ -71,12 +74,21 @@ class UDFSuite extends QueryTest with SharedSQLContext { } } - test("error reporting for incorrect number of arguments") { + test("error reporting for incorrect number of arguments - builtin function") { val df = spark.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } - assert(e.getMessage.contains("arguments")) + assert(e.getMessage.contains("Invalid number of arguments for function substr")) + } + + test("error reporting for incorrect number of arguments - udf") { + val df = spark.emptyDataFrame + val e = intercept[AnalysisException] { + spark.udf.register("foo", (_: String).length) + df.selectExpr("foo(2, 3, 4)") + } + assert(e.getMessage.contains("Invalid number of arguments for function foo")) } test("error reporting for undefined functions") { @@ -93,9 +105,29 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } - test("ZeroArgument UDF") { - spark.udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) + test("UDF defined using UserDefinedFunction") { + import functions.udf + val foo = udf((x: Int) => x + 1) + spark.udf.register("foo", foo) + assert(sql("select foo(5)").head().getInt(0) == 6) + } + + test("ZeroArgument non-deterministic UDF") { + val foo = udf(() => Math.random()) + spark.udf.register("random0", foo.asNondeterministic()) + val df = sql("SELECT random0()") + assert(df.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df.head().getDouble(0) >= 0.0) + + val foo1 = foo.asNondeterministic() + val df1 = testData.select(foo1()) + assert(df1.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df1.head().getDouble(0) >= 0.0) + + val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic() + val df2 = testData.select(bar()) + assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df2.head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { @@ -256,10 +288,12 @@ class UDFSuite extends 
QueryTest with SharedSQLContext { val sparkPlan = spark.sessionState.executePlan(explain).executedPlan sparkPlan.executeCollect().map(_.getString(0).trim).headOption.getOrElse("") } - val udf1 = "myUdf1" - val udf2 = "myUdf2" - spark.udf.register(udf1, (n: Int) => { n + 1 }) - spark.udf.register(udf2, (n: Int) => { n * 1 }) - assert(explainStr(sql("SELECT myUdf1(myUdf2(1))")).contains(s"UDF:$udf1(UDF:$udf2(1))")) + val udf1Name = "myUdf1" + val udf2Name = "myUdf2" + val udf1 = spark.udf.register(udf1Name, (n: Int) => n + 1) + val udf2 = spark.udf.register(udf2Name, (n: Int) => n * 1) + assert(explainStr(sql("SELECT myUdf1(myUdf2(1))")).contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) + assert(explainStr(spark.range(1).select(udf1(udf2(functions.lit(1))))) + .contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index a32763db054f..a5f904c621e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -101,9 +101,22 @@ class UnsafeRowSuite extends SparkFunSuite { MemoryAllocator.UNSAFE.free(offheapRowPage) } } + val (bytesFromArrayBackedRowWithOffset, field0StringFromArrayBackedRowWithOffset) = { + val baos = new ByteArrayOutputStream() + val numBytes = arrayBackedUnsafeRow.getSizeInBytes + val bytesWithOffset = new Array[Byte](numBytes + 100) + System.arraycopy(arrayBackedUnsafeRow.getBaseObject.asInstanceOf[Array[Byte]], 0, + bytesWithOffset, 100, numBytes) + val arrayBackedRow = new UnsafeRow(arrayBackedUnsafeRow.numFields()) + arrayBackedRow.pointTo(bytesWithOffset, Platform.BYTE_ARRAY_OFFSET + 100, numBytes) + arrayBackedRow.writeToStream(baos, null) + (baos.toByteArray, arrayBackedRow.getString(0)) + } assert(bytesFromArrayBackedRow === bytesFromOffheapRow) assert(field0StringFromArrayBackedRow === field0StringFromOffheapRow) + assert(bytesFromArrayBackedRow === bytesFromArrayBackedRowWithOffset) + assert(field0StringFromArrayBackedRow === field0StringFromArrayBackedRowWithOffset) } test("calling getDouble() and getFloat() on null columns") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index b096a6db8517..a08433ba794d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -203,12 +203,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT // Tests to make sure that all operators correctly convert types on the way out. 
test("Local UDTs") { - val df = Seq((1, new UDT.MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec") - df.collect()(0).getAs[UDT.MyDenseVector](1) - df.take(1)(0).getAs[UDT.MyDenseVector](1) - df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[UDT.MyDenseVector](0) - df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0) - .getAs[UDT.MyDenseVector](0) + val vec = new UDT.MyDenseVector(Array(0.1, 1.0)) + val df = Seq((1, vec)).toDF("int", "vec") + assert(vec === df.collect()(0).getAs[UDT.MyDenseVector](1)) + assert(vec === df.take(1)(0).getAs[UDT.MyDenseVector](1)) + checkAnswer(df.limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) + checkAnswer(df.orderBy('int).limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) } test("UDTs with JSON") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala index f54e23e3aa6c..7cfee4957557 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.sql.api.r diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index f7f1ccea281c..423e1288e8dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -38,7 +38,7 @@ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { val rootPath = df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get .asInstanceOf[FileSourceScanExec].relation.location.rootPaths.head - assert(rootPath.toString.contains(basePath.toString)) + assert(rootPath.toString.contains(dir.toURI.getPath.stripSuffix("/"))) assert(!df.queryExecution.sparkPlan.treeString(verbose = true).contains(rootPath.getName)) assert(!df.queryExecution.executedPlan.treeString(verbose = true).contains(rootPath.getName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 06bce9a2400e..737eeb0af586 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -280,7 +280,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { Seq(Some(5), None).foreach { minNumPostShufflePartitions => val testNameNote = minNumPostShufflePartitions match { - case Some(numPartitions) => "(minNumPostShufflePartitions: 3)" + case Some(numPartitions) => "(minNumPostShufflePartitions: " + numPartitions + ")" case None => "" } @@ -300,13 +300,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = agg.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 1) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -314,7 +314,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -351,13 +351,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. 
val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -365,7 +365,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -377,7 +377,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: complex query 1$testNameNote") { - val test = { spark: SparkSession => + val test: (SparkSession) => Unit = { spark: SparkSession => val df1 = spark .range(0, 1000, 1, numInputPartitions) @@ -407,13 +407,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 4) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -429,7 +429,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: complex query 2$testNameNote") { - val test = { spark: SparkSession => + val test: (SparkSession) => Unit = { spark: SparkSession => val df1 = spark .range(0, 1000, 1, numInputPartitions) @@ -459,13 +459,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. 
val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 3) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 59eaf4d1c29b..aac8d56ba620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext @@ -31,7 +31,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => ShuffleExchange(SinglePartition, plan), + plan => ShuffleExchangeExec(SinglePartition, plan), input.map(Row.fromTuple) ) } @@ -81,12 +81,12 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(plan sameResult plan) val part1 = HashPartitioning(output, 1) - val exchange1 = ShuffleExchange(part1, plan) - val exchange2 = ShuffleExchange(part1, plan) + val exchange1 = ShuffleExchangeExec(part1, plan) + val exchange2 = ShuffleExchangeExec(part1, plan) val part2 = HashPartitioning(output, 2) - val exchange3 = ShuffleExchange(part2, plan) + val exchange3 = ShuffleExchangeExec(part2, plan) val part3 = HashPartitioning(output ++ output, 2) - val exchange4 = ShuffleExchange(part3, plan) + val exchange4 = ShuffleExchangeExec(part3, plan) val exchange5 = ReusedExchangeExec(output, exchange4) assert(exchange1 sameResult exchange1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index 00c5f2550cbb..efe28afab08e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -67,7 +67,10 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + val array = new ExternalAppendOnlyUnsafeRowArray( + ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, + numSpillThreshold) + rows.foreach(x => array.add(x)) val iterator = array.generateIterator() @@ -130,7 +133,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { false)) val unsafeRow = new UnsafeRow(1) - val iter = array.getIterator + val iter = array.getIterator(0) while (iter.hasNext) { iter.loadNext() 
unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) @@ -143,7 +146,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold) rows.foreach(x => array.add(x)) val iterator = array.generateIterator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index 53c41639942b..ecc7264d7944 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -31,7 +31,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar override def afterAll(): Unit = TaskContext.unset() - private def withExternalArray(spillThreshold: Int) + private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int) (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { sc = new SparkContext("local", "test", new SparkConf(false)) @@ -45,6 +45,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar taskContext, 1024, SparkEnv.get.memoryManager.pageSizeBytes, + inMemoryThreshold, spillThreshold) try f(array) finally { array.clear() @@ -109,9 +110,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar assert(getNumBytesSpilled > 0) } - test("insert rows less than the spillThreshold") { - val spillThreshold = 100 - withExternalArray(spillThreshold) { array => + test("insert rows less than the inMemoryThreshold") { + val (inMemoryThreshold, spillThreshold) = (100, 50) + withExternalArray(inMemoryThreshold, spillThreshold) { array => assert(array.isEmpty) val expectedValues = populateRows(array, 1) @@ -122,8 +123,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) // Verify that NO spill has happened - populateRows(array, spillThreshold - 1, expectedValues) - assert(array.length == spillThreshold) + populateRows(array, inMemoryThreshold - 1, expectedValues) + assert(array.length == inMemoryThreshold) assertNoSpill() val iterator2 = validateData(array, expectedValues) @@ -133,20 +134,42 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } } - test("insert rows more than the spillThreshold to force spill") { - val spillThreshold = 100 - withExternalArray(spillThreshold) { array => - val numValuesInserted = 20 * spillThreshold - + test("insert rows more than the inMemoryThreshold but less than spillThreshold") { + val (inMemoryThreshold, spillThreshold) = (10, 50) + withExternalArray(inMemoryThreshold, spillThreshold) { array => assert(array.isEmpty) - val expectedValues = populateRows(array, 1) - assert(array.length == 1) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) + val iterator1 = validateData(array, expectedValues) + assertNoSpill() + + // Add more rows to trigger switch to [[UnsafeExternalSorter]] but not too many to cause a + // spill to happen. 
Verify that NO spill has happened + populateRows(array, spillThreshold - expectedValues.length - 1, expectedValues) + assert(array.length == spillThreshold - 1) + assertNoSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + } + + test("insert rows enough to force spill") { + val (inMemoryThreshold, spillThreshold) = (20, 10) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + assert(array.isEmpty) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) val iterator1 = validateData(array, expectedValues) + assertNoSpill() - // Populate more rows to trigger spill. Verify that spill has happened - populateRows(array, numValuesInserted - 1, expectedValues) - assert(array.length == numValuesInserted) + // Add more rows to trigger switch to [[UnsafeExternalSorter]] and cause a spill to happen. + // Verify that spill has happened + populateRows(array, 2, expectedValues) + assert(array.length == inMemoryThreshold + 1) assertSpill() val iterator2 = validateData(array, expectedValues) @@ -158,7 +181,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("iterator on an empty array should be empty") { - withExternalArray(spillThreshold = 10) { array => + withExternalArray(inMemoryThreshold = 4, spillThreshold = 10) { array => val iterator = array.generateIterator() assert(array.isEmpty) assert(array.length == 0) @@ -167,7 +190,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with negative start index") { - withExternalArray(spillThreshold = 2) { array => + withExternalArray(inMemoryThreshold = 100, spillThreshold = 56) { array => val exception = intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) @@ -178,8 +201,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with start index exceeding array's size (without spill)") { - val spillThreshold = 2 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => populateRows(array, spillThreshold / 2) val exception = @@ -191,8 +214,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with start index exceeding array's size (with spill)") { - val spillThreshold = 2 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => populateRows(array, spillThreshold * 2) val exception = @@ -205,10 +228,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with custom start index (without spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => - val expectedValues = populateRows(array, spillThreshold) - val startIndex = spillThreshold / 2 + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + val expectedValues = populateRows(array, inMemoryThreshold) + val startIndex = inMemoryThreshold / 2 val iterator = array.generateIterator(startIndex = startIndex) for (i <- startIndex until expectedValues.length) { checkIfValueExists(iterator, 
expectedValues(i)) @@ -217,8 +240,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with custom start index (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => val expectedValues = populateRows(array, spillThreshold * 10) val startIndex = spillThreshold * 2 val iterator = array.generateIterator(startIndex = startIndex) @@ -229,7 +252,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("test iterator invalidation (without spill)") { - withExternalArray(spillThreshold = 10) { array => + withExternalArray(inMemoryThreshold = 10, spillThreshold = 100) { array => // insert 2 rows, iterate until the first row populateRows(array, 2) @@ -254,9 +277,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("test iterator invalidation (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => - // Populate enough rows so that spill has happens + val (inMemoryThreshold, spillThreshold) = (2, 10) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + // Populate enough rows so that spill happens populateRows(array, spillThreshold * 2) assertSpill() @@ -281,7 +304,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("clear on an empty the array") { - withExternalArray(spillThreshold = 2) { array => + withExternalArray(inMemoryThreshold = 2, spillThreshold = 3) { array => val iterator = array.generateIterator() assert(!iterator.hasNext) @@ -299,10 +322,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("clear array (without spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (10, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => // Populate rows ... but not enough to trigger spill - populateRows(array, spillThreshold / 2) + populateRows(array, inMemoryThreshold / 2) assertNoSpill() // Clear the array @@ -311,21 +334,21 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar // Re-populate few rows so that there is no spill // Verify the data. Verify that there was no spill - val expectedValues = populateRows(array, spillThreshold / 3) + val expectedValues = populateRows(array, inMemoryThreshold / 2) validateData(array, expectedValues) assertNoSpill() // Populate more rows .. enough to not trigger a spill. // Verify the data. 
Verify that there was no spill - populateRows(array, spillThreshold / 3, expectedValues) + populateRows(array, inMemoryThreshold / 2, expectedValues) validateData(array, expectedValues) assertNoSpill() } } test("clear array (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (10, 20) + withExternalArray(inMemoryThreshold, spillThreshold) { array => // Populate enough rows to trigger spill populateRows(array, spillThreshold * 2) val bytesSpilled = getNumBytesSpilled diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 5c63c6a414f9..a3d75b221ec3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -35,39 +35,47 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { private var globalTempDB: String = _ test("basic semantic") { - sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") + try { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") + + // If there is no database in table name, we should try local temp view first, if not found, + // try table/view in current database, which is "default" in this case. So we expect + // NoSuchTableException here. + intercept[NoSuchTableException](spark.table("src")) - // If there is no database in table name, we should try local temp view first, if not found, - // try table/view in current database, which is "default" in this case. So we expect - // NoSuchTableException here. - intercept[NoSuchTableException](spark.table("src")) + // Use qualified name to refer to the global temp view explicitly. + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) - // Use qualified name to refer to the global temp view explicitly. - checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + // Table name without database will never refer to a global temp view. + intercept[NoSuchTableException](sql("DROP VIEW src")) - // Table name without database will never refer to a global temp view. - intercept[NoSuchTableException](sql("DROP VIEW src")) + sql(s"DROP VIEW $globalTempDB.src") + // The global temp view should be dropped successfully. + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) - sql(s"DROP VIEW $globalTempDB.src") - // The global temp view should be dropped successfully. - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + // We can also use Dataset API to create global temp view + Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) - // We can also use Dataset API to create global temp view - Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src") - checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + // Use qualified name to rename a global temp view. + sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2") + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a")) - // Use qualified name to rename a global temp view. - sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2") - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) - checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a")) + // Use qualified name to alter a global temp view. 
+ sql(s"ALTER VIEW $globalTempDB.src2 AS SELECT 2, 'b'") + checkAnswer(spark.table(s"$globalTempDB.src2"), Row(2, "b")) - // Use qualified name to alter a global temp view. - sql(s"ALTER VIEW $globalTempDB.src2 AS SELECT 2, 'b'") - checkAnswer(spark.table(s"$globalTempDB.src2"), Row(2, "b")) + // We can also use Catalog API to drop global temp view + spark.catalog.dropGlobalTempView("src2") + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2")) - // We can also use Catalog API to drop global temp view - spark.catalog.dropGlobalTempView("src2") - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2")) + // We can also use Dataset API to replace global temp view + Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(2, "b")) + } finally { + spark.catalog.dropGlobalTempView("src") + } } test("global temp view is shared among all sessions") { @@ -106,7 +114,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("CREATE TABLE LIKE should work for global temp view") { try { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") - sql(s"CREATE TABLE cloned LIKE ${globalTempDB}.src") + sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false)) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala index 58c310596ca6..78c1e5dae566 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala @@ -42,14 +42,14 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { private def assertMetadataOnlyQuery(df: DataFrame): Unit = { val localRelations = df.queryExecution.optimizedPlan.collect { - case l @ LocalRelation(_, _) => l + case l @ LocalRelation(_, _, _) => l } assert(localRelations.size == 1) } private def assertNotMetadataOnlyQuery(df: DataFrame): Unit = { val localRelations = df.queryExecution.optimizedPlan.collect { - case l @ LocalRelation(_, _) => l + case l @ LocalRelation(_, _, _) => l } assert(localRelations.size == 0) } @@ -117,4 +117,12 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { "select partcol1, max(partcol2) from srcpart where partcol1 = 0 group by rollup (partcol1)", "select partcol2 from (select partcol2 from srcpart where partcol1 = 0 union all " + "select partcol2 from srcpart where partcol1 = 1) t group by partcol2") + + test("SPARK-21884 Fix StackOverflowError on MetadataOnlyQuery") { + withTable("t_1000") { + sql("CREATE TABLE t_1000 (a INT, p INT) USING PARQUET PARTITIONED BY (p)") + (1 to 1000).foreach(p => sql(s"ALTER TABLE t_1000 ADD PARTITION (p=$p)")) + sql("SELECT COUNT(DISTINCT p) FROM t_1000").collect() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 4d155d538d63..86066362da9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD 
import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -214,7 +214,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (small.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: ShuffleExchange => exchange + case exchange: ShuffleExchangeExec => exchange }.length assert(numExchanges === 5) } @@ -229,7 +229,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (normal.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: ShuffleExchange => exchange + case exchange: ShuffleExchangeExec => exchange }.length assert(numExchanges === 5) } @@ -300,7 +300,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -338,7 +338,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -358,7 +358,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") } } @@ -381,7 +381,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") } } @@ -391,7 +391,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchange(finalPartitioning, + val inputPlan = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = 
DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -400,7 +400,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") } } @@ -411,7 +411,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchange(finalPartitioning, + val inputPlan = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -420,7 +420,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") } } @@ -430,7 +430,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val shuffle = ShuffleExchange(finalPartitioning, + val shuffle = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -449,7 +449,7 @@ class PlannerSuite extends SharedSQLContext { if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { fail(s"Should re-use the shuffle:\n$outputPlan") } - if (outputPlan.collect { case e: ShuffleExchange => true }.size != 1) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size != 1) { fail(s"Should have only one shuffle:\n$outputPlan") } @@ -459,14 +459,14 @@ class PlannerSuite extends SharedSQLContext { Literal(1) :: Nil, Inner, None, - ShuffleExchange(finalPartitioning, inputPlan), - ShuffleExchange(finalPartitioning, inputPlan)) + ShuffleExchangeExec(finalPartitioning, inputPlan), + ShuffleExchangeExec(finalPartitioning, inputPlan)) val outputPlan2 = ReuseExchange(spark.sessionState.conf).apply(inputPlan2) if (outputPlan2.collect { case e: ReusedExchangeExec => true }.size != 2) { fail(s"Should re-use the two shuffles:\n$outputPlan2") } - if (outputPlan2.collect { case e: ShuffleExchange => true }.size != 2) { + if (outputPlan2.collect { case e: ShuffleExchangeExec => true }.size != 2) { fail(s"Should have only two shuffles:\n$outputPlan") } } @@ -513,26 +513,30 @@ class PlannerSuite extends SharedSQLContext { } test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") { - val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) - // Both left and right keys should be sorted after the SMJ. 
- Seq(orderingA, orderingB).foreach { ordering => - assertSortRequirementsAreSatisfied( - childPlan = innerSmj, - requiredOrdering = Seq(ordering), - shouldHaveSort = false) + Seq(Inner, Cross).foreach { joinType => + val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB) + // Both left and right keys should be sorted after the SMJ. + Seq(orderingA, orderingB).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = innerSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } } } test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " + "child SMJ") { - val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) - val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC) - // After the second SMJ, exprA, exprB and exprC should all be sorted. - Seq(orderingA, orderingB, orderingC).foreach { ordering => - assertSortRequirementsAreSatisfied( - childPlan = parentSmj, - requiredOrdering = Seq(ordering), - shouldHaveSort = false) + Seq(Inner, Cross).foreach { joinType => + val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB) + val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, joinType, None, childSmj, planC) + // After the second SMJ, exprA, exprB and exprC should all be sorted. + Seq(orderingA, orderingB, orderingC).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = parentSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 1c1931b6a6da..964440346deb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -16,37 +16,36 @@ */ package org.apache.spark.sql.execution -import java.util.Locale - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.test.SharedSQLContext class QueryExecutionSuite extends SharedSQLContext { test("toString() exception/error handling") { - val badRule = new SparkStrategy { - var mode: String = "" - override def apply(plan: LogicalPlan): Seq[SparkPlan] = - mode.toLowerCase(Locale.ROOT) match { - case "exception" => throw new AnalysisException(mode) - case "error" => throw new Error(mode) - case _ => Nil - } - } - spark.experimental.extraStrategies = badRule :: Nil + spark.experimental.extraStrategies = Seq( + new SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil + }) - def qe: QueryExecution = new QueryExecution(spark, OneRowRelation) + def qe: QueryExecution = new QueryExecution(spark, OneRowRelation()) // Nothing! - badRule.mode = "" assert(qe.toString.contains("OneRowRelation")) // Throw an AnalysisException - this should be captured. - badRule.mode = "exception" + spark.experimental.extraStrategies = Seq( + new SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = + throw new AnalysisException("exception") + }) assert(qe.toString.contains("org.apache.spark.sql.AnalysisException")) // Throw an Error - this should not be captured. 
- badRule.mode = "error" + spark.experimental.extraStrategies = Seq( + new SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = + throw new Error("error") + }) val error = intercept[Error](qe.toString) assert(error.getMessage.contains("error")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index fe78a7656883..f6b006b98edd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -26,22 +26,9 @@ import org.apache.spark.sql.SparkSession class SQLExecutionSuite extends SparkFunSuite { test("concurrent query execution (SPARK-10548)") { - // Try to reproduce the issue with the old SparkContext val conf = new SparkConf() .setMaster("local[*]") .setAppName("test") - val badSparkContext = new BadSparkContext(conf) - try { - testConcurrentQueryExecution(badSparkContext) - fail("unable to reproduce SPARK-10548") - } catch { - case e: IllegalArgumentException => - assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) - } finally { - badSparkContext.stop() - } - - // Verify that the issue is fixed with the latest SparkContext val goodSparkContext = new SparkContext(conf) try { testConcurrentQueryExecution(goodSparkContext) @@ -134,17 +121,6 @@ class SQLExecutionSuite extends SparkFunSuite { } } -/** - * A bad [[SparkContext]] that does not clone the inheritable thread local properties - * when passing them to children threads. - */ -private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { - protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) - override protected def initialValue(): Properties = new Properties() - } -} - object SQLExecutionSuite { @volatile var canProgress = false } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala new file mode 100644 index 000000000000..c2e62b987e0c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.json4s.jackson.JsonMethods.parse + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart +import org.apache.spark.util.JsonProtocol + +class SQLJsonProtocolSuite extends SparkFunSuite { + + test("SparkPlanGraph backward compatibility: metadata") { + val SQLExecutionStartJsonString = + """ + |{ + | "Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart", + | "executionId":0, + | "description":"test desc", + | "details":"test detail", + | "physicalPlanDescription":"test plan", + | "sparkPlanInfo": { + | "nodeName":"TestNode", + | "simpleString":"test string", + | "children":[], + | "metadata":{}, + | "metrics":[] + | }, + | "time":0 + |} + """.stripMargin + val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) + val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", "test plan", + new SparkPlanInfo("TestNode", "test string", Nil, Nil), 0) + assert(reconstructedEvent == expectedEvent) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index d32716c18ddf..6761f05bb462 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -669,4 +669,14 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { "positive.")) } } + + test("permanent view should be case-preserving") { + withView("v") { + sql("CREATE VIEW v AS SELECT 1 as aBc") + assert(spark.table("v").schema.head.name == "aBc") + + sql("CREATE OR REPLACE VIEW v AS SELECT 2 as cBa") + assert(spark.table("v").schema.head.name == "cBa") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index 52e4f047225d..1c6fc3530cbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution +import org.apache.spark.TestUtils.assertSpilled import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.TestUtils.assertSpilled case class WindowData(month: Int, area: String, product: Int) @@ -356,6 +356,46 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { spark.catalog.dropTempView("nums") } + test("window function: mutiple window expressions specified by range in a single expression") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.createOrReplaceTempView("nums") + withTempView("nums") { + val expected = + Row(1, 1, 1, 4, null, 8, 25) :: + Row(1, 3, 4, 9, 1, 12, 24) :: + Row(1, 5, 9, 15, 4, 16, 21) :: + Row(1, 7, 16, 21, 8, 9, 16) :: + Row(1, 9, 25, 16, 12, null, 9) :: + Row(0, 2, 2, 6, null, 10, 30) :: + Row(0, 4, 6, 12, 2, 14, 28) :: + Row(0, 6, 12, 18, 6, 18, 24) :: + Row(0, 8, 20, 24, 10, 10, 18) :: + Row(0, 10, 30, 18, 14, null, 10) :: + Nil + + val actual = sql( + """ + |SELECT + | y, + | x, + | sum(x) over w1 as history_sum, + | sum(x) over w2 as period_sum1, + | sum(x) over w3 as period_sum2, + | sum(x) over w4 as period_sum3, + | sum(x) over w5 as 
future_sum + |FROM nums + |WINDOW + | w1 AS (PARTITION BY y ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), + | w2 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING), + | w3 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 4 PRECEDING AND 2 PRECEDING ), + | w4 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 4 FOLLOWING), + | w5 AS (PARTITION BY y ORDER BY x RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) + """.stripMargin + ) + checkAnswer(actual, expected) + } + } + test("SPARK-7595: Window will cause resolve failed with self join") { checkAnswer(sql( """ @@ -437,7 +477,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW) """.stripMargin) - withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") { + withSQLConf("spark.sql.windowExec.buffer.in.memory.threshold" -> "1", + "spark.sql.windowExec.buffer.spill.threshold" -> "2") { assertSpilled(sparkContext, "test with low buffer spill threshold") { checkAnswer(actual, expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala new file mode 100644 index 000000000000..aaf51b5b9011 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +/** + * Tests for the sameResult function for [[SparkPlan]]s. 
+ */ +class SameResultSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("FileSourceScanExec: different orders of data filters and partition filters") { + withTempPath { path => + val tmpDir = path.getCanonicalPath + spark.range(10) + .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d") + .write + .partitionBy("a", "b") + .parquet(tmpDir) + val df = spark.read.parquet(tmpDir) + // partition filters: a > 1 AND b < 9 + // data filters: c > 1 AND d < 9 + val plan1 = getFileSourceScanExec(df.where("a > 1 AND b < 9 AND c > 1 AND d < 9")) + val plan2 = getFileSourceScanExec(df.where("b < 9 AND a > 1 AND d < 9 AND c > 1")) + assert(plan1.sameResult(plan2)) + } + } + + private def getFileSourceScanExec(df: DataFrame): FileSourceScanExec = { + df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get + .asInstanceOf[FileSourceScanExec] + } + + test("SPARK-20725: partial aggregate should behave correctly for sameResult") { + val df1 = spark.range(10).agg(sum($"id")) + val df2 = spark.range(10).agg(sum($"id")) + assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan)) + + val df3 = spark.range(10).agg(sumDistinct($"id")) + val df4 = spark.range(10).agg(sumDistinct($"id")) + assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala index aecfd3062147..5828f9783da4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala @@ -40,7 +40,7 @@ class SparkPlannerSuite extends SharedSQLContext { case Union(children) => planned += 1 UnionExec(children.map(planLater)) :: planLater(NeverPlanned) :: Nil - case LocalRelation(output, data) => + case LocalRelation(output, data, _) => planned += 1 LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil case NeverPlanned => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 908b955abbf0..107a2f710979 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Concat, SortOrder} import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.CreateTable +import 
org.apache.spark.sql.execution.datasources.{CreateTable, RefreshResource} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -36,7 +35,7 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType * See [[org.apache.spark.sql.catalyst.parser.PlanParserSuite]] for rules * defined in the Catalyst module. */ -class SparkSqlParserSuite extends PlanTest { +class SparkSqlParserSuite extends AnalysisTest { val newConf = new SQLConf private lazy val parser = new SparkSqlParser(newConf) @@ -67,6 +66,25 @@ class SparkSqlParserSuite extends PlanTest { } } + test("refresh resource") { + assertEqual("REFRESH prefix_path", RefreshResource("prefix_path")) + assertEqual("REFRESH /", RefreshResource("/")) + assertEqual("REFRESH /path///a", RefreshResource("/path///a")) + assertEqual("REFRESH pat1h/112/_1a", RefreshResource("pat1h/112/_1a")) + assertEqual("REFRESH pat1h/112/_1a/a-1", RefreshResource("pat1h/112/_1a/a-1")) + assertEqual("REFRESH path-with-dash", RefreshResource("path-with-dash")) + assertEqual("REFRESH \'path with space\'", RefreshResource("path with space")) + assertEqual("REFRESH \"path with space 2\"", RefreshResource("path with space 2")) + intercept("REFRESH a b", "REFRESH statements cannot contain") + intercept("REFRESH a\tb", "REFRESH statements cannot contain") + intercept("REFRESH a\nb", "REFRESH statements cannot contain") + intercept("REFRESH a\rb", "REFRESH statements cannot contain") + intercept("REFRESH a\r\nb", "REFRESH statements cannot contain") + intercept("REFRESH @ $a$", "REFRESH statements cannot contain") + intercept("REFRESH ", "Resource paths cannot be empty in REFRESH statements") + intercept("REFRESH", "Resource paths cannot be empty in REFRESH statements") + } + test("show functions") { assertEqual("show functions", ShowFunctionsCommand(None, None, true, true)) assertEqual("show all functions", ShowFunctionsCommand(None, None, true, true)) @@ -231,8 +249,34 @@ class SparkSqlParserSuite extends PlanTest { assertEqual("describe table formatted t", DescribeTableCommand( TableIdentifier("t"), Map.empty, isExtended = true)) + } + + test("describe table column") { + assertEqual("DESCRIBE t col", + DescribeColumnCommand( + TableIdentifier("t"), Seq("col"), isExtended = false)) + assertEqual("DESCRIBE t `abc.xyz`", + DescribeColumnCommand( + TableIdentifier("t"), Seq("abc.xyz"), isExtended = false)) + assertEqual("DESCRIBE t abc.xyz", + DescribeColumnCommand( + TableIdentifier("t"), Seq("abc", "xyz"), isExtended = false)) + assertEqual("DESCRIBE t `a.b`.`x.y`", + DescribeColumnCommand( + TableIdentifier("t"), Seq("a.b", "x.y"), isExtended = false)) - intercept("explain describe tables x", "Unsupported SQL statement") + assertEqual("DESCRIBE TABLE t col", + DescribeColumnCommand( + TableIdentifier("t"), Seq("col"), isExtended = false)) + assertEqual("DESCRIBE TABLE EXTENDED t col", + DescribeColumnCommand( + TableIdentifier("t"), Seq("col"), isExtended = true)) + assertEqual("DESCRIBE TABLE FORMATTED t col", + DescribeColumnCommand( + TableIdentifier("t"), Seq("col"), isExtended = true)) + + intercept("DESCRIBE TABLE t PARTITION (ds='1970-01-01') col", + "DESC TABLE COLUMN for a specific partition is not supported") } test("analyze table statistics") { @@ -241,17 +285,33 @@ class SparkSqlParserSuite extends PlanTest { assertEqual("analyze table t compute statistics noscan", AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) assertEqual("analyze table t 
partition (a) compute statistics nOscAn", - AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + AnalyzePartitionCommand(TableIdentifier("t"), Map("a" -> None), noscan = true)) - // Partitions specified - we currently parse them but don't do anything with it + // Partitions specified assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", - AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + AnalyzePartitionCommand(TableIdentifier("t"), noscan = false, + partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> Some("11")))) assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", - AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + AnalyzePartitionCommand(TableIdentifier("t"), noscan = true, + partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> Some("11")))) + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09') COMPUTE STATISTICS noscan", + AnalyzePartitionCommand(TableIdentifier("t"), noscan = true, + partitionSpec = Map("ds" -> Some("2008-04-09")))) + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr) COMPUTE STATISTICS", + AnalyzePartitionCommand(TableIdentifier("t"), noscan = false, + partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> None))) + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr) COMPUTE STATISTICS noscan", + AnalyzePartitionCommand(TableIdentifier("t"), noscan = true, + partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> None))) + assertEqual("ANALYZE TABLE t PARTITION(ds, hr=11) COMPUTE STATISTICS noscan", + AnalyzePartitionCommand(TableIdentifier("t"), noscan = true, + partitionSpec = Map("ds" -> None, "hr" -> Some("11")))) assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS", - AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + AnalyzePartitionCommand(TableIdentifier("t"), noscan = false, + partitionSpec = Map("ds" -> None, "hr" -> None))) assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS noscan", - AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + AnalyzePartitionCommand(TableIdentifier("t"), noscan = true, + partitionSpec = Map("ds" -> None, "hr" -> None))) intercept("analyze table t compute statistics xxxx", "Expected `NOSCAN` instead of `xxxx`") @@ -264,6 +324,11 @@ class SparkSqlParserSuite extends PlanTest { assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value", AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) + + // Partition specified - should be ignored + assertEqual("ANALYZE TABLE t PARTITION(ds='2017-06-10') " + + "COMPUTE STATISTICS FOR COLUMNS key, value", + AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) } test("query organization") { @@ -290,4 +355,15 @@ class SparkSqlParserSuite extends PlanTest { basePlan, numPartitions = newConf.numShufflePartitions))) } + + test("pipeline concatenation") { + val concat = Concat( + Concat(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil) :: + UnresolvedAttribute("c") :: + Nil + ) + assertEqual( + "SELECT a || b || c FROM t", + Project(UnresolvedAlias(concat) :: Nil, UnresolvedRelation(TableIdentifier("t")))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 6cf18de0cc76..d194f58cd1cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -111,8 +111,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 1024, // initial capacity, - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) assert(!map.iterator().next()) map.free() @@ -125,13 +124,13 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 1024, // initial capacity - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) val groupKey = InternalRow(UTF8String.fromString("cats")) + val row = map.getAggregationBuffer(groupKey) // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) - assert(map.getAggregationBuffer(groupKey) != null) + assert(row != null) val iter = map.iterator() assert(iter.next()) iter.getKey.getString(0) should be ("cats") @@ -140,7 +139,7 @@ class UnsafeFixedWidthAggregationMapSuite // Modifications to rows retrieved from the map should update the values in the map iter.getValue.setInt(0, 42) - map.getAggregationBuffer(groupKey).getInt(0) should be (42) + row.getInt(0) should be (42) map.free() } @@ -152,8 +151,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 128, // initial capacity - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) val rand = new Random(42) val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet @@ -178,8 +176,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 128, // initial capacity - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) val keys = randomStrings(1024).take(512) @@ -226,8 +223,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 128, // initial capacity - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) val sorter = map.destructAndCreateExternalSorter() @@ -267,8 +263,7 @@ class UnsafeFixedWidthAggregationMapSuite StructType(Nil), taskMemoryManager, 128, // initial capacity - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) (1 to 10).foreach { i => val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0)) @@ -312,8 +307,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 128, // initial capacity - pageSize, - false // disable perf metrics + pageSize ) val rand = new Random(42) @@ -350,8 +344,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 128, // initial capacity - pageSize, - false // disable perf metrics + pageSize ) val rand = new Random(42) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 53105e0b2495..dff88ce7f1b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -29,8 +29,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.Utils +import org.apache.spark.util.collection.ExternalSorter /** * used to test close InputStream in UnsafeRowSerializer diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index a4b30a2f8cec..beeee6a97c8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,10 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{Column, Dataset, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions.{avg, broadcast, col, max} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -127,4 +130,80 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { "named_struct('a',id+2, 'b',id+2) as col2") .filter("col1 = col2").count() } + + test("SPARK-21441 SortMergeJoin codegen with CodegenFallback expressions should be disabled") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { + import testImplicits._ + + val df1 = Seq((1, 1), (2, 2), (3, 3)).toDF("key", "int") + val df2 = Seq((1, "1"), (2, "2"), (3, "3")).toDF("key", "str") + + val df = df1.join(df2, df1("key") === df2("key")) + .filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1") + .select("int") + + val plan = df.queryExecution.executedPlan + assert(!plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.children(0) + .isInstanceOf[SortMergeJoinExec]).isDefined) + assert(df.collect() === Array(Row(1), Row(2))) + } + } + + def genGroupByCodeGenContext(caseNum: Int): CodegenContext = { + val caseExp = (1 to caseNum).map { i => + s"case when id > $i and id <= ${i + 1} then 1 else 0 end as v$i" + }.toList + val keyExp = List( + "id", + "(id & 1023) as k1", + "cast(id & 1023 as double) as k2", + "cast(id & 1023 as int) as k3") + + val ds = spark.range(10) + .selectExpr(keyExp:::caseExp: _*) + .groupBy("k1", "k2", "k3") + .sum() + val plan = ds.queryExecution.executedPlan + + val wholeStageCodeGenExec = plan.find(p => p match { + case wp: WholeStageCodegenExec => wp.child match { + case hp: HashAggregateExec if (hp.child.isInstanceOf[ProjectExec]) => true + case _ => false + } + case _ => false + }) + + assert(wholeStageCodeGenExec.isDefined) + wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._1 + } + + test("SPARK-21603 check there is a too long generated function") { + withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") { + val ctx = genGroupByCodeGenContext(30) + assert(ctx.isTooLongGeneratedFunction === true) + } + } + + test("SPARK-21603 check there is not a too long generated function") { + withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") { + val ctx = genGroupByCodeGenContext(1) + assert(ctx.isTooLongGeneratedFunction === false) + } + } + + test("SPARK-21603 check there is not a too long generated function when threshold is Int.Max") { + 
withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> Int.MaxValue.toString) { + val ctx = genGroupByCodeGenContext(30) + assert(ctx.isTooLongGeneratedFunction === false) + } + } + + test("SPARK-21603 check there is a too long generated function when threshold is 0") { + withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "0") { + val ctx = genGroupByCodeGenContext(1) + assert(ctx.isTooLongGeneratedFunction === true) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index bc9cb6ec2e77..10f1ee279bed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -23,8 +23,8 @@ import scala.collection.mutable import org.apache.spark._ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.unsafe.KVIterator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala new file mode 100644 index 000000000000..30422b657742 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -0,0 +1,1692 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.arrow + +import java.io.File +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat +import java.util.Locale + +import com.google.common.io.Files +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} +import org.apache.arrow.vector.file.json.JsonFileReader +import org.apache.arrow.vector.util.Validator +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} +import org.apache.spark.util.Utils + + +class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { + import testImplicits._ + + private var tempDataPath: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath + } + + test("collect to arrow record batch") { + val indexData = (1 to 6).toDF("i") + val arrowPayloads = indexData.toArrowPayload.collect() + assert(arrowPayloads.nonEmpty) + assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val rowCount = arrowRecordBatches.map(_.getLength).sum + assert(rowCount === indexData.count()) + arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + test("short conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_s", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 32767, -32768 ] + | }, { + | "name" : "b_s", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -32768 ] + | } ] + | } ] + |} + """.stripMargin + + val a_s = List[Short](1, -1, 2, -2, 32767, -32768) + val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) + val df = a_s.zip(b_s).toDF("a_s", "b_s") + + collectAndValidate(df, json, "integer-16bit.json") + } + + test("int conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, 
+ | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + collectAndValidate(df, json, "integer-32bit.json") + } + + test("long conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_l", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 9223372036854775807, -9223372036854775808 ] + | }, { + | "name" : "b_l", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -9223372036854775808 ] + | } ] + | } ] + |} + """.stripMargin + + val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) + val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) + val df = a_l.zip(b_l).toDF("a_l", "b_l") + + collectAndValidate(df, json, "integer-64bit.json") + } + + test("float conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_f", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0 ] + | }, { + | "name" : "b_f", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + + val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) + val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) + val df = a_f.zip(b_f).toDF("a_f", "b_f") + + collectAndValidate(df, json, "floating_point-single_precision.json") + } + + 
test("double conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 1.0E-4, 20000.0 ] + | }, { + | "name" : "b_d", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) + val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + + collectAndValidate(df, json, "floating_point-double_precision.json") + } + + test("index conversion") { + val data = List[Int](1, 2, 3, 4, 5, 6) + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + val df = data.toDF("i") + + collectAndValidate(df, json, "indexData-ints.json") + } + + test("mixed numeric type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "c", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "e", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : 
false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "b", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "c", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "e", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + + val data = List(1, 2, 3, 4, 5, 6) + val data_tuples = for (d <- data) yield { + (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) + } + val df = data_tuples.toDF("a", "b", "c", "d", "e") + + collectAndValidate(df, json, "mixed_numeric_types.json") + } + + test("string type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "upper_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "lower_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "null_str", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "upper_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "A", "B", "C" ] + | }, { + | "name" : "lower_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "a", "b", "c" ] + | }, { + | "name" : "null_str", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 0 ], + | "OFFSET" : [ 0, 2, 5, 5 ], + | "DATA" : [ "ab", "CDE", "" ] + | } ] + | } ] + |} + """.stripMargin + + val upperCase = Seq("A", "B", "C") + val lowerCase = Seq("a", "b", "c") + val nullStr = Seq("ab", "CDE", null) + val df = (upperCase, lowerCase, nullStr).zipped.toList + .toDF("upper_case", "lower_case", "null_str") + + collectAndValidate(df, json, "stringData.json") + } + + test("boolean type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_bool", + | "type" : { + | "name" : "bool" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 1 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_bool", + | "count" : 4, + | "VALIDITY" 
: [ 1, 1, 1, 1 ], + | "DATA" : [ true, true, false, true ] + | } ] + | } ] + |} + """.stripMargin + val df = Seq(true, true, false, true).toDF("a_bool") + collectAndValidate(df, json, "boolData.json") + } + + test("byte type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_byte", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 8 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_byte", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 64, 127 ] + | } ] + | } ] + |} + | + """.stripMargin + val df = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") + collectAndValidate(df, json, "byteData.json") + } + + test("binary type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_binary", + | "type" : { + | "name" : "binary" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a_binary", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 3, 4, 6 ], + | "DATA" : [ "616263", "64", "6566" ] + | } ] + | } ] + |} + """.stripMargin + + val data = Seq("abc", "d", "ef") + val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8")))) + val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType)))) + + collectAndValidate(df, json, "binaryData.json") + } + + test("floating-point NaN") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "NaN_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "NaN_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 2, + | "columns" : [ { + | "name" : "NaN_f", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1.2000000476837158, "NaN" ] + | }, { + | "name" : "NaN_d", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ "NaN", 1.2 ] + | } ] + | } ] + |} + """.stripMargin + + val fnan = Seq(1.2F, Float.NaN) + val dnan = Seq(Double.NaN, 1.2) + val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") + + collectAndValidate(df, json, "nanData-floating_point.json") + } + + test("array type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : 
"VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "c_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "d_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 4, 4, 5 ], + | "children" : [ { + | "name" : "element", + | "count" : 5, + | "VALIDITY" : [ 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5 ] + | } ] + | }, { + | "name" : "b_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 0, 1, 0 ], + | "OFFSET" : [ 0, 2, 2, 2, 2 ], + | "children" : [ { + | "name" : "element", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1, 2 ] + | } ] + | }, { + | "name" : "c_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 4, 4, 5 ], + | "children" : [ { + | "name" : "element", + | "count" : 5, + | "VALIDITY" : [ 1, 1, 1, 0, 1 ], + | "DATA" : [ 1, 2, 3, 0, 5 ] + | } ] + | }, { + | "name" : "d_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 3, 3, 4 ], + | "children" : [ { + | "name" : "element", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 3, 3, 4 ], + | "children" : [ { + | "name" : "element", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 5 ] + | } ] + | } ] + | } ] + | } ] + |} + 
""".stripMargin + + val aArr = Seq(Seq(1, 2), Seq(3, 4), Seq(), Seq(5)) + val bArr = Seq(Some(Seq(1, 2)), None, Some(Seq()), None) + val cArr = Seq(Seq(Some(1), Some(2)), Seq(Some(3), None), Seq(), Seq(Some(5))) + val dArr = Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq()), Seq(), Seq(Seq(5))) + + val df = aArr.zip(bArr).zip(cArr).zip(dArr).map { + case (((a, b), c), d) => (a, b, c, d) + }.toDF("a_arr", "b_arr", "c_arr", "d_arr") + + collectAndValidate(df, json, "arrayData.json") + } + + test("struct type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_struct", + | "nullable" : false, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "b_struct", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "c_struct", + | "nullable" : false, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "d_struct", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "nested", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "b_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "c_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | 
"VALIDITY" : [ 1, 0, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "d_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "children" : [ { + | "name" : "nested", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 0 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 0 ], + | "DATA" : [ 1, 2, 0 ] + | } ] + | } ] + | } ] + | } ] + |} + """.stripMargin + + val aStruct = Seq(Row(1), Row(2), Row(3)) + val bStruct = Seq(Row(1), null, Row(3)) + val cStruct = Seq(Row(1), Row(null), Row(3)) + val dStruct = Seq(Row(Row(1)), null, Row(null)) + val data = aStruct.zip(bStruct).zip(cStruct).zip(dStruct).map { + case (((a, b), c), d) => Row(a, b, c, d) + } + + val rdd = sparkContext.parallelize(data) + val schema = new StructType() + .add("a_struct", new StructType().add("i", IntegerType, nullable = false), nullable = false) + .add("b_struct", new StructType().add("i", IntegerType, nullable = false), nullable = true) + .add("c_struct", new StructType().add("i", IntegerType, nullable = true), nullable = false) + .add("d_struct", new StructType().add("nested", new StructType().add("i", IntegerType))) + val df = spark.createDataFrame(rdd, schema) + + collectAndValidate(df, json, "structData.json") + } + + test("partitioned DataFrame") { + val json1 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 1, 2 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 1 ] + | } ] + | } ] + |} + """.stripMargin + val json2 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 3, 3 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 1, 2 ] + | } ] + | } ] + |} + """.stripMargin + + val arrowPayloads = testData2.toArrowPayload.collect() + // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload + assert(arrowPayloads.length === 2) + val schema = testData2.schema + + val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") + val tempFile2 = new 
File(tempDataPath, "testData2-ints-part2.json") + Files.write(json1, tempFile1, StandardCharsets.UTF_8) + Files.write(json2, tempFile2, StandardCharsets.UTF_8) + + validateConversion(schema, arrowPayloads(0), tempFile1) + validateConversion(schema, arrowPayloads(1), tempFile2) + } + + test("empty frame collect") { + val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() + assert(arrowPayload.isEmpty) + + val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") + val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() + assert(filteredArrowPayload.isEmpty) + } + + test("empty partition collect") { + val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") + val arrowPayloads = emptyPart.toArrowPayload.collect() + assert(arrowPayloads.length === 1) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + assert(arrowRecordBatches.head.getLength == 1) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + test("max records in batch conf") { + val totalRecords = 10 + val maxRecordsPerBatch = 3 + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) + val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") + val arrowPayloads = df.toArrowPayload.collect() + assert(arrowPayloads.length >= 4) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + var recordCount = 0 + arrowRecordBatches.foreach { batch => + assert(batch.getLength > 0) + assert(batch.getLength <= maxRecordsPerBatch) + recordCount += batch.getLength + batch.close() + } + assert(recordCount == totalRecords) + allocator.close() + spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + } + + testQuietly("unsupported types") { + def runUnsupported(block: => Unit): Unit = { + val msg = intercept[SparkException] { + block + } + assert(msg.getMessage.contains("Unsupported data type")) + assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) + } + + runUnsupported { decimalData.toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowPayload.collect() } + runUnsupported { complexData.toArrowPayload.collect() } + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) + val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) + runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } + + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) + runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } + } + + test("test Arrow Validator") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | 
"count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + val json_diff_col_order = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + // Different schema + intercept[IllegalArgumentException] { + collectAndValidate(df, json_diff_col_order, "validator_diff_schema.json") + } + + // Different values + intercept[IllegalArgumentException] { + collectAndValidate(df.sort($"a_i".desc), json, "validator_diff_values.json") + } + } + + test("roundtrip payloads") { + val inputRows = (0 until 9).map { i => + InternalRow(i) + } :+ InternalRow(null) + + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + + val ctx = TaskContext.empty() + val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx) + val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) + + assert(schema.equals(outputRowIter.schema)) + + var count = 0 + outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { + assert(row.getInt(0) == i) + } else { + assert(row.isNullAt(0)) + } + count += 1 + } + + assert(count == inputRows.length) + } + + /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ + private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { + // NOTE: coalesce to single partition because can only load 1 batch in validator + val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val tempFile = new File(tempDataPath, file) + Files.write(json, tempFile, StandardCharsets.UTF_8) + validateConversion(df.schema, arrowPayload, tempFile) + } + + private def validateConversion( + sparkSchema: StructType, + arrowPayload: ArrowPayload, + jsonFile: File): Unit = { + val allocator = new RootAllocator(Long.MaxValue) + val jsonReader = new JsonFileReader(jsonFile, allocator) + + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema) + val jsonSchema = jsonReader.start() + Validator.compareSchemas(arrowSchema, jsonSchema) + + val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) + val vectorLoader = new VectorLoader(arrowRoot) + 
val arrowRecordBatch = arrowPayload.loadBatch(allocator) + vectorLoader.load(arrowRecordBatch) + val jsonRoot = jsonReader.read() + Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) + + jsonRoot.close() + jsonReader.close() + arrowRecordBatch.close() + arrowRoot.close() + allocator.close() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala new file mode 100644 index 000000000000..638619fd39d0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class ArrowUtilsSuite extends SparkFunSuite { + + def roundtrip(dt: DataType): Unit = { + dt match { + case schema: StructType => + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema)) === schema) + case _ => + roundtrip(new StructType().add("value", dt)) + } + } + + test("simple") { + roundtrip(BooleanType) + roundtrip(ByteType) + roundtrip(ShortType) + roundtrip(IntegerType) + roundtrip(LongType) + roundtrip(FloatType) + roundtrip(DoubleType) + roundtrip(StringType) + roundtrip(BinaryType) + roundtrip(DecimalType.SYSTEM_DEFAULT) + } + + test("array") { + roundtrip(ArrayType(IntegerType, containsNull = true)) + roundtrip(ArrayType(IntegerType, containsNull = false)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = true)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = false)) + } + + test("struct") { + roundtrip(new StructType()) + roundtrip(new StructType().add("i", IntegerType)) + roundtrip(new StructType().add("arr", ArrayType(IntegerType))) + roundtrip(new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType))) + roundtrip(new StructType().add( + "struct", + new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType)))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala new file mode 100644 index 000000000000..e9a629315f5f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.execution.vectorized.ArrowColumnVector +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ArrowWriterSuite extends SparkFunSuite { + + test("simple") { + def check(dt: DataType, data: Seq[Any]): Unit = { + val schema = new StructType().add("value", dt, nullable = true) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + data.foreach { datum => + writer.write(InternalRow(datum)) + } + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + data.zipWithIndex.foreach { + case (null, rowId) => assert(reader.isNullAt(rowId)) + case (datum, rowId) => + val value = dt match { + case BooleanType => reader.getBoolean(rowId) + case ByteType => reader.getByte(rowId) + case ShortType => reader.getShort(rowId) + case IntegerType => reader.getInt(rowId) + case LongType => reader.getLong(rowId) + case FloatType => reader.getFloat(rowId) + case DoubleType => reader.getDouble(rowId) + case StringType => reader.getUTF8String(rowId) + case BinaryType => reader.getBinary(rowId) + } + assert(value === datum) + } + + writer.root.close() + } + check(BooleanType, Seq(true, null, false)) + check(ByteType, Seq(1.toByte, 2.toByte, null, 4.toByte)) + check(ShortType, Seq(1.toShort, 2.toShort, null, 4.toShort)) + check(IntegerType, Seq(1, 2, null, 4)) + check(LongType, Seq(1L, 2L, null, 4L)) + check(FloatType, Seq(1.0f, 2.0f, null, 4.0f)) + check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) + check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) + check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) + } + + test("get multiple") { + def check(dt: DataType, data: Seq[Any]): Unit = { + val schema = new StructType().add("value", dt, nullable = false) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + data.foreach { datum => + writer.write(InternalRow(datum)) + } + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + val values = dt match { + case BooleanType => reader.getBooleans(0, data.size) + case ByteType => reader.getBytes(0, data.size) + case ShortType => reader.getShorts(0, data.size) + case IntegerType => reader.getInts(0, data.size) + case LongType => reader.getLongs(0, data.size) + case FloatType => reader.getFloats(0, data.size) + case DoubleType => reader.getDoubles(0, data.size) + } + assert(values === data) + + writer.root.close() + } + check(BooleanType, Seq(true, false)) + check(ByteType, (0 until 10).map(_.toByte)) + check(ShortType, (0 until 10).map(_.toShort)) + check(IntegerType, 
(0 until 10)) + check(LongType, (0 until 10).map(_.toLong)) + check(FloatType, (0 until 10).map(_.toFloat)) + check(DoubleType, (0 until 10).map(_.toDouble)) + } + + test("array") { + val schema = new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) + writer.write(InternalRow(ArrayData.toArrayData(Array(4, 5)))) + writer.write(InternalRow(null)) + writer.write(InternalRow(ArrayData.toArrayData(Array.empty[Int]))) + writer.write(InternalRow(ArrayData.toArrayData(Array(6, null, 8)))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val array0 = reader.getArray(0) + assert(array0.numElements() === 3) + assert(array0.getInt(0) === 1) + assert(array0.getInt(1) === 2) + assert(array0.getInt(2) === 3) + + val array1 = reader.getArray(1) + assert(array1.numElements() === 2) + assert(array1.getInt(0) === 4) + assert(array1.getInt(1) === 5) + + assert(reader.isNullAt(2)) + + val array3 = reader.getArray(3) + assert(array3.numElements() === 0) + + val array4 = reader.getArray(4) + assert(array4.numElements() === 3) + assert(array4.getInt(0) === 6) + assert(array4.isNullAt(1)) + assert(array4.getInt(2) === 8) + + writer.root.close() + } + + test("nested array") { + val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(ArrayData.toArrayData(Array( + ArrayData.toArrayData(Array(1, 2, 3)), + ArrayData.toArrayData(Array(4, 5)), + null, + ArrayData.toArrayData(Array.empty[Int]), + ArrayData.toArrayData(Array(6, null, 8)))))) + writer.write(InternalRow(null)) + writer.write(InternalRow(ArrayData.toArrayData(Array.empty))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val array0 = reader.getArray(0) + assert(array0.numElements() === 5) + + val array00 = array0.getArray(0) + assert(array00.numElements() === 3) + assert(array00.getInt(0) === 1) + assert(array00.getInt(1) === 2) + assert(array00.getInt(2) === 3) + + val array01 = array0.getArray(1) + assert(array01.numElements() === 2) + assert(array01.getInt(0) === 4) + assert(array01.getInt(1) === 5) + + assert(array0.isNullAt(2)) + + val array03 = array0.getArray(3) + assert(array03.numElements() === 0) + + val array04 = array0.getArray(4) + assert(array04.numElements() === 3) + assert(array04.getInt(0) === 6) + assert(array04.isNullAt(1)) + assert(array04.getInt(2) === 8) + + assert(reader.isNullAt(1)) + + val array2 = reader.getArray(2) + assert(array2.numElements() === 0) + + writer.root.close() + } + + test("struct") { + val schema = new StructType() + .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) + writer.write(InternalRow(InternalRow(null, null))) + writer.write(InternalRow(null)) + writer.write(InternalRow(InternalRow(4, null))) + writer.write(InternalRow(InternalRow(null, UTF8String.fromString("str5")))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val struct0 = reader.getStruct(0, 2) + assert(struct0.getInt(0) === 1) + assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) + + val struct1 = 
reader.getStruct(1, 2) + assert(struct1.isNullAt(0)) + assert(struct1.isNullAt(1)) + + assert(reader.isNullAt(2)) + + val struct3 = reader.getStruct(3, 2) + assert(struct3.getInt(0) === 4) + assert(struct3.isNullAt(1)) + + val struct4 = reader.getStruct(4, 2) + assert(struct4.isNullAt(0)) + assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) + + writer.root.close() + } + + test("nested struct") { + val schema = new StructType().add("struct", + new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType))) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))) + writer.write(InternalRow(InternalRow(InternalRow(null, null)))) + writer.write(InternalRow(InternalRow(null))) + writer.write(InternalRow(null)) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val struct00 = reader.getStruct(0, 1).getStruct(0, 2) + assert(struct00.getInt(0) === 1) + assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) + + val struct10 = reader.getStruct(1, 1).getStruct(0, 2) + assert(struct10.isNullAt(0)) + assert(struct10.isNullAt(1)) + + val struct2 = reader.getStruct(2, 1) + assert(struct2.isNullAt(0)) + + assert(reader.isNullAt(3)) + + writer.root.close() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 8a798fb44469..691fa9ac5e1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -301,6 +301,68 @@ class AggregateBenchmark extends BenchmarkBase { */ } + ignore("max function length of wholestagecodegen") { + val N = 20 << 15 + + val benchmark = new Benchmark("max function length of wholestagecodegen", N) + def f(): Unit = sparkSession.range(N) + .selectExpr( + "id", + "(id & 1023) as k1", + "cast(id & 1023 as double) as k2", + "cast(id & 1023 as int) as k3", + "case when id > 100 and id <= 200 then 1 else 0 end as v1", + "case when id > 200 and id <= 300 then 1 else 0 end as v2", + "case when id > 300 and id <= 400 then 1 else 0 end as v3", + "case when id > 400 and id <= 500 then 1 else 0 end as v4", + "case when id > 500 and id <= 600 then 1 else 0 end as v5", + "case when id > 600 and id <= 700 then 1 else 0 end as v6", + "case when id > 700 and id <= 800 then 1 else 0 end as v7", + "case when id > 800 and id <= 900 then 1 else 0 end as v8", + "case when id > 900 and id <= 1000 then 1 else 0 end as v9", + "case when id > 1000 and id <= 1100 then 1 else 0 end as v10", + "case when id > 1100 and id <= 1200 then 1 else 0 end as v11", + "case when id > 1200 and id <= 1300 then 1 else 0 end as v12", + "case when id > 1300 and id <= 1400 then 1 else 0 end as v13", + "case when id > 1400 and id <= 1500 then 1 else 0 end as v14", + "case when id > 1500 and id <= 1600 then 1 else 0 end as v15", + "case when id > 1600 and id <= 1700 then 1 else 0 end as v16", + "case when id > 1700 and id <= 1800 then 1 else 0 end as v17", + "case when id > 1800 and id <= 1900 then 1 else 0 end as v18") + .groupBy("k1", "k2", "k3") + .sum() + .collect() + + benchmark.addCase(s"codegen = F") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + f() + } + + benchmark.addCase(s"codegen = T 
maxLinesPerFunction = 10000") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "10000") + f() + } + + benchmark.addCase(s"codegen = T maxLinesPerFunction = 1500") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "1500") + f() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_111-b14 on Windows 7 6.1 + Intel64 Family 6 Model 58 Stepping 9, GenuineIntel + max function length of wholestagecodegen: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ---------------------------------------------------------------------------------------------- + codegen = F 462 / 533 1.4 704.4 1.0X + codegen = T maxLinesPerFunction = 10000 3444 / 3447 0.2 5255.3 0.1X + codegen = T maxLinesPerFunction = 1500 447 / 478 1.5 682.1 1.0X + */ + } + ignore("cube") { val N = 5 << 20 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala index 46db41a8abad..5a25d7230837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.benchmark +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.IntegerType @@ -35,7 +36,9 @@ class JoinBenchmark extends BenchmarkBase { val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v")) runBenchmark("Join w long", N) { - sparkSession.range(N).join(dim, (col("id") % M) === col("k")).count() + val df = sparkSession.range(N).join(dim, (col("id") % M) === col("k")) + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + df.count() } /* @@ -55,7 +58,9 @@ class JoinBenchmark extends BenchmarkBase { val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v")) runBenchmark("Join w long duplicated", N) { val dim = broadcast(sparkSession.range(M).selectExpr("cast(id/10 as long) as k")) - sparkSession.range(N).join(dim, (col("id") % M) === col("k")).count() + val df = sparkSession.range(N).join(dim, (col("id") % M) === col("k")) + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + df.count() } /* @@ -75,9 +80,11 @@ class JoinBenchmark extends BenchmarkBase { .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v")) runBenchmark("Join w 2 ints", N) { - sparkSession.range(N).join(dim2, + val df = sparkSession.range(N).join(dim2, (col("id") % M).cast(IntegerType) === col("k1") - && (col("id") % M).cast(IntegerType) === col("k2")).count() + && (col("id") % M).cast(IntegerType) === col("k2")) + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + df.count() } /* @@ -97,9 +104,10 @@ class JoinBenchmark extends BenchmarkBase { .selectExpr("id as k1", "id as k2", "cast(id as string) as v")) runBenchmark("Join w 2 longs", N) { - sparkSession.range(N).join(dim3, + val df = sparkSession.range(N).join(dim3, (col("id") % M) === col("k1") && (col("id") % M) === col("k2")) - .count() + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + df.count() } /* @@ -119,9 +127,10 @@ class 
JoinBenchmark extends BenchmarkBase { .selectExpr("cast(id/10 as long) as k1", "cast(id/10 as long) as k2")) runBenchmark("Join w 2 longs duplicated", N) { - sparkSession.range(N).join(dim4, + val df = sparkSession.range(N).join(dim4, (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) - .count() + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + df.count() } /* @@ -138,7 +147,9 @@ class JoinBenchmark extends BenchmarkBase { val M = 1 << 16 val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v")) runBenchmark("outer join w long", N) { - sparkSession.range(N).join(dim, (col("id") % M) === col("k"), "left").count() + val df = sparkSession.range(N).join(dim, (col("id") % M) === col("k"), "left") + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + df.count() } /* @@ -156,7 +167,9 @@ class JoinBenchmark extends BenchmarkBase { val M = 1 << 16 val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v")) runBenchmark("semi join w long", N) { - sparkSession.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count() + val df = sparkSession.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi") + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + df.count() } /* @@ -174,7 +187,9 @@ class JoinBenchmark extends BenchmarkBase { runBenchmark("merge join", N) { val df1 = sparkSession.range(N).selectExpr(s"id * 2 as k1") val df2 = sparkSession.range(N).selectExpr(s"id * 3 as k2") - df1.join(df2, col("k1") === col("k2")).count() + val df = df1.join(df2, col("k1") === col("k2")) + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) + df.count() } /* @@ -193,7 +208,9 @@ class JoinBenchmark extends BenchmarkBase { .selectExpr(s"(id * 15485863) % ${N*10} as k1") val df2 = sparkSession.range(N) .selectExpr(s"(id * 15485867) % ${N*10} as k2") - df1.join(df2, col("k1") === col("k2")).count() + val df = df1.join(df2, col("k1") === col("k2")) + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) + df.count() } /* @@ -212,18 +229,19 @@ class JoinBenchmark extends BenchmarkBase { sparkSession.conf.set("spark.sql.join.preferSortMergeJoin", "false") runBenchmark("shuffle hash join", N) { val df1 = sparkSession.range(N).selectExpr(s"id as k1") - val df2 = sparkSession.range(N / 5).selectExpr(s"id * 3 as k2") - df1.join(df2, col("k1") === col("k2")).count() + val df2 = sparkSession.range(N / 3).selectExpr(s"id * 3 as k2") + val df = df1.join(df2, col("k1") === col("k2")) + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[ShuffledHashJoinExec]).isDefined) + df.count() } /* - *Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 - *Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + *Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Windows 7 6.1 + *Intel64 Family 6 Model 94 Stepping 3, GenuineIntel *shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative *------------------------------------------------------------------------------------------- - *shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X - *shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X + *shuffle hash join codegen=false 2005 / 2010 2.1 478.0 1.0X + *shuffle hash join codegen=true 1773 / 1792 2.4 422.7 1.1X */ } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 239822b72034..69247d7f4e9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -17,23 +17,21 @@ package org.apache.spark.sql.execution.benchmark -import java.io.File - import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.util.Benchmark /** * Benchmark to measure TPCDS query performance. * To run this: - * spark-submit --class --jars + * spark-submit --class --data-location */ -object TPCDSQueryBenchmark { +object TPCDSQueryBenchmark extends Logging { val conf = new SparkConf() .setMaster("local[1]") @@ -43,6 +41,7 @@ object TPCDSQueryBenchmark { .set("spark.driver.memory", "3g") .set("spark.executor.memory", "3g") .set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 1024).toString) + .set("spark.sql.crossJoin.enabled", "true") val spark = SparkSession.builder.config(conf).getOrCreate() @@ -60,32 +59,21 @@ object TPCDSQueryBenchmark { } def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = { - require(dataLocation.nonEmpty, - "please modify the value of dataLocation to point to your local TPCDS data") val tableSizes = setupTables(dataLocation) queries.foreach { name => - val queryString = fileToString(new File(Thread.currentThread().getContextClassLoader - .getResource(s"tpcds/$name.sql").getFile)) + val queryString = resourceToString(s"tpcds/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the - // logical plan and adding up the sizes of all tables that appear in the plan. Note that this - // currently doesn't take WITH subqueries into account which might lead to fairly inaccurate - // per-row processing time for those cases. + // logical plan and adding up the sizes of all tables that appear in the plan. 
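Editorial aside, not part of the patch: the per-query row-count estimate described in the comment above works by walking the analyzed plan of each TPCDS query and summing the pre-computed sizes of every leaf relation it touches. Below is a minimal self-contained sketch of that logic; the helper name `estimateInputRows` and its parameters are illustrative only, while the pattern matches mirror the replacement code that follows in this hunk and assume the Spark version this patch targets (where `LogicalRelation` carries an optional `CatalogTable` and `SubqueryAlias` holds a plain string alias).

```scala
import scala.collection.mutable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
import org.apache.spark.sql.execution.datasources.LogicalRelation

object InputSizeEstimation {
  // Sums the pre-computed row counts of every table referenced by the query,
  // mirroring the estimation logic added to TPCDSQueryBenchmark.tpcdsAll.
  def estimateInputRows(
      spark: SparkSession,
      tableSizes: Map[String, Long],
      queryString: String): Long = {
    val queryRelations = mutable.HashSet[String]()
    spark.sql(queryString).queryExecution.analyzed.foreach {
      // Data source tables typically appear wrapped in a SubqueryAlias.
      case SubqueryAlias(alias, _: LogicalRelation) => queryRelations.add(alias)
      // Bare data source relations expose their catalog table, if any.
      case LogicalRelation(_, _, Some(catalogTable), _) =>
        queryRelations.add(catalogTable.identifier.table)
      // Hive tables are represented by HiveTableRelation after analysis.
      case HiveTableRelation(tableMeta, _, _) => queryRelations.add(tableMeta.identifier.table)
      case _ =>
    }
    queryRelations.map(tableSizes.getOrElse(_, 0L)).sum
  }
}
```

Matching on `queryExecution.analyzed` rather than the unresolved logical plan is what lets this cover both data source and Hive relations directly, replacing the earlier walk over `UnresolvedRelation` nodes and subquery expressions.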
val queryRelations = scala.collection.mutable.HashSet[String]() - spark.sql(queryString).queryExecution.logical.map { - case ur @ UnresolvedRelation(t: TableIdentifier) => - queryRelations.add(t.table) - case lp: LogicalPlan => - lp.expressions.foreach { _ foreach { - case subquery: SubqueryExpression => - subquery.plan.foreach { - case ur @ UnresolvedRelation(t: TableIdentifier) => - queryRelations.add(t.table) - case _ => - } - case _ => - } - } + spark.sql(queryString).queryExecution.analyzed.foreach { + case SubqueryAlias(alias, _: LogicalRelation) => + queryRelations.add(alias) + case LogicalRelation(_, _, Some(catalogTable), _) => + queryRelations.add(catalogTable.identifier.table) + case HiveTableRelation(tableMeta, _, _) => + queryRelations.add(tableMeta.identifier.table) case _ => } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum @@ -93,11 +81,14 @@ object TPCDSQueryBenchmark { benchmark.addCase(name) { i => spark.sql(queryString).collect() } + logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name =====\n") benchmark.run() + logInfo(s"\n\n===== FINISHED $name =====\n") } } def main(args: Array[String]): Unit = { + val benchmarkArgs = new TPCDSQueryBenchmarkArguments(args) // List of all TPC-DS queries val tpcdsQueries = Seq( @@ -112,12 +103,20 @@ object TPCDSQueryBenchmark { "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") - // In order to run this benchmark, please follow the instructions at - // https://github.com/databricks/spark-sql-perf/blob/master/README.md to generate the TPCDS data - // locally (preferably with a scale factor of 5 for benchmarking). Thereafter, the value of - // dataLocation below needs to be set to the location where the generated data is stored. - val dataLocation = "" + // If `--query-filter` defined, filters the queries that this option selects + val queriesToRun = if (benchmarkArgs.queryFilter.nonEmpty) { + val queries = tpcdsQueries.filter { case queryName => + benchmarkArgs.queryFilter.contains(queryName) + } + if (queries.isEmpty) { + throw new RuntimeException( + s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") + } + queries + } else { + tpcdsQueries + } - tpcdsAll(dataLocation, queries = tpcdsQueries) + tpcdsAll(benchmarkArgs.dataLocation, queries = queriesToRun) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala new file mode 100644 index 000000000000..184ffff94298 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.util.Locale + + +class TPCDSQueryBenchmarkArguments(val args: Array[String]) { + var dataLocation: String = null + var queryFilter: Set[String] = Set.empty + + parseArgs(args.toList) + validateArguments() + + private def optionMatch(optionName: String, s: String): Boolean = { + optionName == s.toLowerCase(Locale.ROOT) + } + + private def parseArgs(inputArgs: List[String]): Unit = { + var args = inputArgs + + while (args.nonEmpty) { + args match { + case optName :: value :: tail if optionMatch("--data-location", optName) => + dataLocation = value + args = tail + + case optName :: value :: tail if optionMatch("--query-filter", optName) => + queryFilter = value.toLowerCase(Locale.ROOT).split(",").map(_.trim).toSet + args = tail + + case _ => + // scalastyle:off println + System.err.println("Unknown/unsupported param " + args) + // scalastyle:on println + printUsageAndExit(1) + } + } + } + + private def printUsageAndExit(exitCode: Int): Unit = { + // scalastyle:off + System.err.println(""" + |Usage: spark-submit --class [Options] + |Options: + | --data-location Path to TPCDS data + | --query-filter Queries to filter, e.g., q3,q5,q13 + | + |------------------------------------------------------------------------------------------------------------------ + |In order to run this benchmark, please follow the instructions at + |https://github.com/databricks/spark-sql-perf/blob/master/README.md + |to generate the TPCDS data locally (preferably with a scale factor of 5 for benchmarking). + |Thereafter, the value of needs to be set to the location where the generated data is stored. 
+ """.stripMargin) + // scalastyle:on + System.exit(exitCode) + } + + private def validateArguments(): Unit = { + if (dataLocation == null) { + // scalastyle:off println + System.err.println("Must specify a data location") + // scalastyle:on println + printUsageAndExit(-1) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index b2d04f7c5a6e..d4e7e362c6c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -18,33 +18,29 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) - testColumnStats(classOf[DoubleColumnStats], DOUBLE, - createRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) - testDecimalColumnStats(createRow(null, null, 0)) - - def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0)) + testDecimalColumnStats(Array(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: GenericInternalRow): Unit = { + initialStatistics: Array[Any]): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } } @@ -60,11 +56,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - 
assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -73,14 +69,14 @@ class ColumnStatsSuite extends SparkFunSuite { } def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( - initialStatistics: GenericInternalRow): Unit = { + initialStatistics: Array[Any]): Unit = { val columnStatsName = classOf[DecimalColumnStats].getSimpleName val columnType = COMPACT_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new DecimalColumnStats(15, 10) - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } } @@ -96,11 +92,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 5f2a3aaff634..ff05049551dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -144,4 +144,18 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { ColumnType(DecimalType(19, 0)) } } + + test("show type name in type mismatch error") { + val invalidType = new DataType { + override def defaultSize: Int = 1 + override private[spark] def asNullable: DataType = this + override def typeName: String = "invalid type name" + } + + val message = intercept[java.lang.Exception] { + ColumnType(invalidType) + }.getMessage + + assert(message.contains("Unsupported type: invalid type name")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 109b1d9db60d..8d411eb191cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -126,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { .toDF().createOrReplaceTempView("sizeTst") spark.catalog.cacheTable("sizeTst") assert( - 
spark.table("sizeTst").queryExecution.analyzed.stats(sqlConf).sizeInBytes > + spark.table("sizeTst").queryExecution.analyzed.stats.sizeInBytes > spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala similarity index 56% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 8a6bc62fec96..fa5172ca8a3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -22,19 +22,26 @@ import java.util.Locale import scala.reflect.{classTag, ClassTag} +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.{Generate, InsertIntoDir, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Project, ScriptTransformation} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -// TODO: merge this with DDLSuite (SPARK-14441) -class DDLCommandSuite extends PlanTest { +class DDLParserSuite extends PlanTest with SharedSQLContext { private lazy val parser = new SparkSqlParser(new SQLConf) private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { @@ -56,6 +63,17 @@ class DDLCommandSuite extends PlanTest { } } + private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = { + val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null) + comparePlans(plan, expected, checkAnalysis = false) + } + + private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { + parser.parsePlan(sql).collect { + case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) + }.head + } + test("create database") { val sql = """ @@ -181,8 +199,29 @@ class DDLCommandSuite extends PlanTest { |'com.matthewrathbone.example.SimpleUDFExample' USING ARCHIVE '/path/to/archive', |FILE '/path/to/file' """.stripMargin + val sql3 = + """ + |CREATE OR REPLACE TEMPORARY FUNCTION helloworld3 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val sql4 = + """ + |CREATE OR REPLACE FUNCTION hello.world1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING ARCHIVE '/path/to/archive', + |FILE '/path/to/file' + """.stripMargin + val sql5 = + """ + |CREATE FUNCTION IF NOT EXISTS hello.world2 as + |'com.matthewrathbone.example.SimpleUDFExample' USING 
ARCHIVE '/path/to/archive', + |FILE '/path/to/file' + """.stripMargin val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) + val parsed3 = parser.parsePlan(sql3) + val parsed4 = parser.parsePlan(sql4) + val parsed5 = parser.parsePlan(sql5) val expected1 = CreateFunctionCommand( None, "helloworld", @@ -190,7 +229,7 @@ class DDLCommandSuite extends PlanTest { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true) + isTemp = true, ifNotExists = false, replace = false) val expected2 = CreateFunctionCommand( Some("hello"), "world", @@ -198,9 +237,36 @@ class DDLCommandSuite extends PlanTest { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false) + isTemp = false, ifNotExists = false, replace = false) + val expected3 = CreateFunctionCommand( + None, + "helloworld3", + "com.matthewrathbone.example.SimpleUDFExample", + Seq( + FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), + FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), + isTemp = true, ifNotExists = false, replace = true) + val expected4 = CreateFunctionCommand( + Some("hello"), + "world1", + "com.matthewrathbone.example.SimpleUDFExample", + Seq( + FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), + FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), + isTemp = false, ifNotExists = false, replace = true) + val expected5 = CreateFunctionCommand( + Some("hello"), + "world2", + "com.matthewrathbone.example.SimpleUDFExample", + Seq( + FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), + FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), + isTemp = false, ifNotExists = true, replace = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + comparePlans(parsed5, expected5) } test("drop function") { @@ -408,6 +474,26 @@ class DDLCommandSuite extends PlanTest { } } + test("create table - with table properties") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet"), + properties = Map("test" -> "test")) + + parser.parsePlan(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + test("create table - with location") { val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" @@ -439,6 +525,55 @@ class DDLCommandSuite extends PlanTest { assert(e.message.contains("you can only specify one of them.")) } + test("insert overwrite directory") { + val v1 = "INSERT OVERWRITE DIRECTORY '/tmp/file' USING parquet SELECT 1 as a" + parser.parsePlan(v1) match { + case InsertIntoDir(_, storage, provider, query, overwrite) => + assert(storage.locationUri.isDefined && 
storage.locationUri.get.toString == "/tmp/file") + case other => + fail(s"Expected to parse ${classOf[InsertIntoDataSourceDirCommand].getClass.getName}" + + " from query," + s" got ${other.getClass.getName}: $v1") + } + + val v2 = "INSERT OVERWRITE DIRECTORY USING parquet SELECT 1 as a" + val e2 = intercept[ParseException] { + parser.parsePlan(v2) + } + assert(e2.message.contains( + "Directory path and 'path' in OPTIONS should be specified one, but not both")) + + val v3 = + """ + | INSERT OVERWRITE DIRECTORY USING json + | OPTIONS ('path' '/tmp/file', a 1, b 0.1, c TRUE) + | SELECT 1 as a + """.stripMargin + parser.parsePlan(v3) match { + case InsertIntoDir(_, storage, provider, query, overwrite) => + assert(storage.locationUri.isDefined && provider == Some("json")) + assert(storage.properties.get("a") == Some("1")) + assert(storage.properties.get("b") == Some("0.1")) + assert(storage.properties.get("c") == Some("true")) + assert(!storage.properties.contains("abc")) + assert(!storage.properties.contains("path")) + case other => + fail(s"Expected to parse ${classOf[InsertIntoDataSourceDirCommand].getClass.getName}" + + " from query," + s"got ${other.getClass.getName}: $v1") + } + + val v4 = + """ + | INSERT OVERWRITE DIRECTORY '/tmp/file' USING json + | OPTIONS ('path' '/tmp/file', a 1, b 0.1, c TRUE) + | SELECT 1 as a + """.stripMargin + val e4 = intercept[ParseException] { + parser.parsePlan(v4) + } + assert(e4.message.contains( + "Directory path and 'path' in OPTIONS should be specified one, but not both")) + } + // ALTER TABLE table_name RENAME TO new_table_name; // ALTER VIEW view_name RENAME TO new_view_name; test("alter table/view: rename table/view") { @@ -998,4 +1133,553 @@ class DDLCommandSuite extends PlanTest { s"got ${other.getClass.getName}: $sql") } } + + test("Test CTAS #1") { + val s1 = + """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |COMMENT 'This is the staging page view table' + |STORED AS RCFILE + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin + + val (desc, exists) = extractTableDesc(s1) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + + test("Test CTAS #2") { + val s2 = + """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |COMMENT 'This is the staging page view table' + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin + + val (desc, 
exists) = extractTableDesc(s2) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + + test("Test CTAS #3") { + val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" + val (desc, exists) = extractTableDesc(s3) + assert(exists == false) + assert(desc.identifier.database == None) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.storage.locationUri == None) + assert(desc.schema.isEmpty) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc.properties == Map()) + } + + test("Test CTAS #4") { + val s4 = + """CREATE TABLE page_view + |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin + intercept[AnalysisException] { + extractTableDesc(s4) + } + } + + test("Test CTAS #5") { + val s5 = """CREATE TABLE ctas2 + | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + | STORED AS RCFile + | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + | AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin + val (desc, exists) = extractTableDesc(s5) + assert(exists == false) + assert(desc.identifier.database == None) + assert(desc.identifier.table == "ctas2") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.storage.locationUri == None) + assert(desc.schema.isEmpty) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.properties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) + assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) + } + + test("CTAS statement with a PARTITIONED BY clause is not allowed") { + assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + + " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp") + } + + test("CTAS statement with schema") { + 
assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src") + assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'") + } + + test("unsupported operations") { + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TEMPORARY TABLE ctas2 + |ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + |WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + |STORED AS RCFile + |TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) + |CLUSTERED BY(user_id) INTO 256 BUCKETS + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) + |SKEWED BY (key) ON (1,5,6) + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.contrib.serde2.TypedBytesSerDe' + |RECORDREADER 'org.apache.hadoop.hive.contrib.util.typedbytes.TypedBytesRecordReader' + |FROM testData + """.stripMargin) + } + } + + test("Invalid interval term should throw AnalysisException") { + def assertError(sql: String, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + parser.parsePlan(sql) + } + assert(e.getMessage.contains(errorMessage)) + } + assertError("select interval '42-32' year to month", + "month 32 outside range [0, 11]") + assertError("select interval '5 49:12:15' day to second", + "hour 49 outside range [0, 23]") + assertError("select interval '.1111111111' second", + "nanosecond 1111111111 outside range") + } + + test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { + val analyzer = spark.sessionState.analyzer + val plan = analyzer.execute(parser.parsePlan( + """ + |SELECT * + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin)) + + assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) + } + + test("transform query spec") { + val p = ScriptTransformation( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + "func", Seq.empty, plans.table("e"), null) + + compareTransformQuery("select transform(a, b) using 'func' from e where f < 10", + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + compareTransformQuery("map a, b using 'func' as c, d from e", + p.copy(output = Seq('c.string, 'd.string))) + compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e", + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("use backticks in output of Script Transform") { + parser.parsePlan( + """SELECT `t`.`thing1` + |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`) + |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t + """.stripMargin) + } + + test("use backticks in output of Generator") { + parser.parsePlan( + """ + |SELECT `gentab2`.`gencol2` + |FROM `default`.`src` + |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1` + |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2` + """.stripMargin) + } + + test("use escaped backticks 
in output of Generator") { + parser.parsePlan( + """ + |SELECT `gen``tab2`.`gen``col2` + |FROM `default`.`src` + |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1` + |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2` + """.stripMargin) + } + + test("create table - basic") { + val query = "CREATE TABLE my_table (id int, name string)" + val (desc, allowExisting) = extractTableDesc(query) + assert(!allowExisting) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.schema == new StructType().add("id", "int").add("name", "string")) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.bucketSpec.isEmpty) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.locationUri.isEmpty) + assert(desc.storage.inputFormat == + Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc.storage.properties.isEmpty) + assert(desc.properties.isEmpty) + assert(desc.comment.isEmpty) + } + + test("create table - with database name") { + val query = "CREATE TABLE dbx.my_table (id int, name string)" + val (desc, _) = extractTableDesc(query) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + } + + test("create table - temporary") { + val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" + val e = intercept[ParseException] { parser.parsePlan(query) } + assert(e.message.contains("CREATE TEMPORARY TABLE is not supported yet")) + } + + test("create table - external") { + val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" + val (desc, _) = extractTableDesc(query) + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere"))) + } + + test("create table - if not exists") { + val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" + val (_, allowExisting) = extractTableDesc(query) + assert(allowExisting) + } + + test("create table - comment") { + val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" + val (desc, _) = extractTableDesc(query) + assert(desc.comment == Some("its hot as hell below")) + } + + test("create table - partitioned columns") { + val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" + val (desc, _) = extractTableDesc(query) + assert(desc.schema == new StructType() + .add("id", "int") + .add("name", "string") + .add("month", "int")) + assert(desc.partitionColumnNames == Seq("month")) + } + + test("create table - clustered by") { + val numBuckets = 10 + val bucketedColumn = "id" + val sortColumn = "id" + val baseQuery = + s""" + CREATE TABLE my_table ( + $bucketedColumn int, + name string) + CLUSTERED BY($bucketedColumn) + """ + + val query1 = s"$baseQuery INTO $numBuckets BUCKETS" + val (desc1, _) = extractTableDesc(query1) + assert(desc1.bucketSpec.isDefined) + val bucketSpec1 = desc1.bucketSpec.get + assert(bucketSpec1.numBuckets == numBuckets) + assert(bucketSpec1.bucketColumnNames.head.equals(bucketedColumn)) + assert(bucketSpec1.sortColumnNames.isEmpty) + + val query2 = s"$baseQuery SORTED BY($sortColumn) INTO $numBuckets BUCKETS" 
+ val (desc2, _) = extractTableDesc(query2) + assert(desc2.bucketSpec.isDefined) + val bucketSpec2 = desc2.bucketSpec.get + assert(bucketSpec2.numBuckets == numBuckets) + assert(bucketSpec2.bucketColumnNames.head.equals(bucketedColumn)) + assert(bucketSpec2.sortColumnNames.head.equals(sortColumn)) + } + + test("create table - skewed by") { + val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" + val query1 = s"$baseQuery(id) ON (1, 10, 100)" + val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" + val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + val e3 = intercept[ParseException] { parser.parsePlan(query3) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + assert(e3.getMessage.contains("Operation not allowed")) + } + + test("create table - row format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT" + val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'" + val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')" + val query3 = + s""" + |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y' + |COLLECTION ITEMS TERMINATED BY 'a' + |MAP KEYS TERMINATED BY 'b' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'c' + """.stripMargin + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + val (desc3, _) = extractTableDesc(query3) + assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc1.storage.properties.isEmpty) + assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc2.storage.properties == Map("k1" -> "v1")) + assert(desc3.storage.properties == Map( + "field.delim" -> "x", + "escape.delim" -> "y", + "serialization.format" -> "x", + "line.delim" -> "\n", + "colelction.delim" -> "a", // yes, it's a typo from Hive :) + "mapkey.delim" -> "b")) + } + + test("create table - file format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS" + val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'" + val query2 = s"$baseQuery ORC" + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + assert(desc1.storage.inputFormat == Some("winput")) + assert(desc1.storage.outputFormat == Some("wowput")) + assert(desc1.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + test("create table - storage handler") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" + val query1 = s"$baseQuery 'org.papachi.StorageHandler'" + val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - properties") { + val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 
'k2'='v2')" + val (desc, _) = extractTableDesc(query) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + } + + test("create table - everything!") { + val query = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string) + |COMMENT 'no comment' + |PARTITIONED BY (month int) + |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1') + |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput' + |LOCATION '/path/to/mercury' + |TBLPROPERTIES ('k1'='v1', 'k2'='v2') + """.stripMargin + val (desc, allowExisting) = extractTableDesc(query) + assert(allowExisting) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.schema == new StructType() + .add("id", "int") + .add("name", "string") + .add("month", "int")) + assert(desc.partitionColumnNames == Seq("month")) + assert(desc.bucketSpec.isEmpty) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.locationUri == Some(new URI("/path/to/mercury"))) + assert(desc.storage.inputFormat == Some("winput")) + assert(desc.storage.outputFormat == Some("wowput")) + assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc.storage.properties == Map("k1" -> "v1")) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + assert(desc.comment == Some("no comment")) + } + + test("create view -- basic") { + val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1" + val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand] + assert(!command.allowExisting) + assert(command.name.database.isEmpty) + assert(command.name.table == "view1") + assert(command.originalText == Some("SELECT * FROM tab1")) + assert(command.userSpecifiedColumns.isEmpty) + } + + test("create view - full") { + val v1 = + """ + |CREATE OR REPLACE VIEW view1 + |(col1, col3 COMMENT 'hello') + |COMMENT 'BLABLA' + |TBLPROPERTIES('prop1Key'="prop1Val") + |AS SELECT * FROM tab1 + """.stripMargin + val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand] + assert(command.name.database.isEmpty) + assert(command.name.table == "view1") + assert(command.userSpecifiedColumns == Seq("col1" -> None, "col3" -> Some("hello"))) + assert(command.originalText == Some("SELECT * FROM tab1")) + assert(command.properties == Map("prop1Key" -> "prop1Val")) + assert(command.comment == Some("BLABLA")) + } + + test("create view -- partitioned view") { + val v1 = "CREATE VIEW view1 partitioned on (ds, hr) as select * from srcpart" + intercept[ParseException] { + parser.parsePlan(v1) + } + } + + test("MSCK REPAIR table") { + val sql = "MSCK REPAIR TABLE tab1" + val parsed = parser.parsePlan(sql) + val expected = AlterTableRecoverPartitionsCommand( + TableIdentifier("tab1", None), + "MSCK REPAIR TABLE") + comparePlans(parsed, expected) + } + + test("create table like") { + val v1 = "CREATE TABLE table1 LIKE table2" + val (target, source, location, exists) = parser.parsePlan(v1).collect { + case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) + }.head + assert(exists == false) + assert(target.database.isEmpty) + assert(target.table == "table1") + assert(source.database.isEmpty) + assert(source.table == "table2") + assert(location.isEmpty) + + val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2" + val (target2, source2, location2, exists2) = parser.parsePlan(v2).collect { + case CreateTableLikeCommand(t, s, l, 
allowExisting) => (t, s, l, allowExisting) + }.head + assert(exists2) + assert(target2.database.isEmpty) + assert(target2.table == "table1") + assert(source2.database.isEmpty) + assert(source2.table == "table2") + assert(location2.isEmpty) + + val v3 = "CREATE TABLE table1 LIKE table2 LOCATION '/spark/warehouse'" + val (target3, source3, location3, exists3) = parser.parsePlan(v3).collect { + case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) + }.head + assert(!exists3) + assert(target3.database.isEmpty) + assert(target3.table == "table1") + assert(source3.database.isEmpty) + assert(source3.table == "table2") + assert(location3 == Some("/spark/warehouse")) + + val v4 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2 LOCATION '/spark/warehouse'" + val (target4, source4, location4, exists4) = parser.parsePlan(v4).collect { + case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) + }.head + assert(exists4) + assert(target4.database.isEmpty) + assert(target4.table == "table1") + assert(source4.database.isEmpty) + assert(source4.table == "table2") + assert(location4 == Some("/spark/warehouse")) + } + + test("load data") { + val v1 = "LOAD DATA INPATH 'path' INTO TABLE table1" + val (table, path, isLocal, isOverwrite, partition) = parser.parsePlan(v1).collect { + case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition) + }.head + assert(table.database.isEmpty) + assert(table.table == "table1") + assert(path == "path") + assert(!isLocal) + assert(!isOverwrite) + assert(partition.isEmpty) + + val v2 = "LOAD DATA LOCAL INPATH 'path' OVERWRITE INTO TABLE table1 PARTITION(c='1', d='2')" + val (table2, path2, isLocal2, isOverwrite2, partition2) = parser.parsePlan(v2).collect { + case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition) + }.head + assert(table2.database.isEmpty) + assert(table2.table == "table1") + assert(path2 == "path") + assert(isLocal2) + assert(isOverwrite2) + assert(partition2.nonEmpty) + assert(partition2.get.apply("c") == "1" && partition2.get.apply("d") == "2") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 2f4eb1b15519..4ed2cecc5faf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -49,7 +49,8 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo protected override def generateTable( catalog: SessionCatalog, - name: TableIdentifier): CatalogTable = { + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable = { val storage = CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) val metadata = new MetadataBuilder() @@ -67,49 +68,10 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo provider = Some("parquet"), partitionColumnNames = Seq("a", "b"), createTime = 0L, + createVersion = org.apache.spark.SPARK_VERSION, tracksPartitionsInCatalog = true) } - test("alter table: set location (datasource table)") { - testSetLocation(isDatasourceTable = true) - } - - test("alter table: set properties (datasource table)") { - testSetProperties(isDatasourceTable = true) - } - - test("alter table: unset properties (datasource table)") { - testUnsetProperties(isDatasourceTable = true) - } - - test("alter table: set serde (datasource table)") 
{ - testSetSerde(isDatasourceTable = true) - } - - test("alter table: set serde partition (datasource table)") { - testSetSerdePartition(isDatasourceTable = true) - } - - test("alter table: change column (datasource table)") { - testChangeColumn(isDatasourceTable = true) - } - - test("alter table: add partition (datasource table)") { - testAddPartitions(isDatasourceTable = true) - } - - test("alter table: drop partition (datasource table)") { - testDropPartitions(isDatasourceTable = true) - } - - test("alter table: rename partition (datasource table)") { - testRenamePartitions(isDatasourceTable = true) - } - - test("drop table - data source table") { - testDropTable(isDatasourceTable = true) - } - test("create a managed Hive source table") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") val tabName = "tbl" @@ -163,7 +125,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive" } - protected def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable + protected def generateTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable private val escapedIdentifier = "`(.+)`".r @@ -205,8 +170,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { ignoreIfExists = false) } - private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { - catalog.createTable(generateTable(catalog, name), ignoreIfExists = false) + private def createTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): Unit = { + catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) } private def createTablePartition( @@ -223,6 +191,46 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri } + test("alter table: set location (datasource table)") { + testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties (datasource table)") { + testSetProperties(isDatasourceTable = true) + } + + test("alter table: unset properties (datasource table)") { + testUnsetProperties(isDatasourceTable = true) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: set serde partition (datasource table)") { + testSetSerdePartition(isDatasourceTable = true) + } + + test("alter table: change column (datasource table)") { + testChangeColumn(isDatasourceTable = true) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: rename partition (datasource table)") { + testRenamePartitions(isDatasourceTable = true) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + test("the qualified path of a database is stored in the catalog") { val catalog = spark.sessionState.catalog @@ -429,16 +437,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("create table - duplicate column names in the table definition") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int, a string) USING json") - } - assert(e.message == "Found duplicate column(s) in table definition of `tbl`: a") - - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - val e2 = 
intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int, A string) USING json") + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val errMsg = intercept[AnalysisException] { + sql(s"CREATE TABLE t($c0 INT, $c1 INT) USING parquet") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the table definition of `t`")) } - assert(e2.message == "Found duplicate column(s) in table definition of `tbl`: a") } } @@ -459,17 +464,33 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("create table - column repeated in partition columns") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int) USING json PARTITIONED BY (a, a)") + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val errMsg = intercept[AnalysisException] { + sql(s"CREATE TABLE t($c0 INT) USING parquet PARTITIONED BY ($c0, $c1)") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the partition schema")) + } } - assert(e.message == "Found duplicate column(s) in partition: a") } - test("create table - column repeated in bucket columns") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int) USING json CLUSTERED BY (a, a) INTO 4 BUCKETS") + test("create table - column repeated in bucket/sort columns") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var errMsg = intercept[AnalysisException] { + sql(s"CREATE TABLE t($c0 INT) USING parquet CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the bucket definition")) + + errMsg = intercept[AnalysisException] { + sql(s""" + |CREATE TABLE t($c0 INT, col INT) USING parquet CLUSTERED BY (col) + | SORTED BY ($c0, $c1) INTO 2 BUCKETS + """.stripMargin) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the sort definition")) + } } - assert(e.message == "Found duplicate column(s) in bucket: a") } test("Refresh table after changing the data source table partitioning") { @@ -521,6 +542,17 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create view - duplicate column names in the view definition") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val errMsg = intercept[AnalysisException] { + sql(s"CREATE VIEW t AS SELECT * FROM VALUES (1, 1) AS t($c0, $c1)") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the view definition")) + } + } + } + test("Alter/Describe Database") { val catalog = spark.sessionState.catalog val databaseNames = Seq("db1", "`database`") @@ -695,7 +727,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { withView("testview") { sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + - s"OPTIONS (PATH '$tmpFile')") + s"OPTIONS (PATH '${tmpFile.toURI}')") checkAnswer( sql("select c1, c2 from testview order by c1 limit 1"), @@ -707,7 +739,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TEMPORARY VIEW testview |USING org.apache.spark.sql.execution.datasources.csv.CSVFileFormat 
- |OPTIONS (PATH '$tmpFile') + |OPTIONS (PATH '${tmpFile.toURI}') """.stripMargin) } } @@ -751,7 +783,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val df = (1 to 2).map { i => (i, i.toString) }.toDF("age", "name") df.write.insertInto("students") spark.catalog.cacheTable("students") - assume(spark.table("students").collect().toSeq == df.collect().toSeq, "bad test: wrong data") + checkAnswer(spark.table("students"), df) assume(spark.catalog.isCached("students"), "bad test: table was not cached in the first place") sql("ALTER TABLE students RENAME TO teachers") sql("CREATE TABLE students (age INT, name STRING) USING parquet") @@ -760,10 +792,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(!spark.catalog.isCached("students")) assert(spark.catalog.isCached("teachers")) assert(spark.table("students").collect().isEmpty) - assert(spark.table("teachers").collect().toSeq == df.collect().toSeq) + checkAnswer(spark.table("teachers"), df) } - test("rename temporary table - destination table with database name") { + test("rename temporary view - destination table with database name") { withTempView("tab1") { sql( """ @@ -780,7 +812,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE tab1 RENAME TO default.tab2") } assert(e.getMessage.contains( - "RENAME TEMPORARY TABLE from '`tab1`' to '`default`.`tab2`': " + + "RENAME TEMPORARY VIEW from '`tab1`' to '`default`.`tab2`': " + "cannot specify database name 'default' in the destination table")) val catalog = spark.sessionState.catalog @@ -788,7 +820,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("rename temporary table") { + test("rename temporary view") { withTempView("tab1", "tab2") { spark.range(10).createOrReplaceTempView("tab1") sql("ALTER TABLE tab1 RENAME TO tab2") @@ -800,7 +832,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("rename temporary table - destination table already exists") { + test("rename temporary view - destination table already exists") { withTempView("tab1", "tab2") { sql( """ @@ -828,39 +860,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE tab1 RENAME TO tab2") } assert(e.getMessage.contains( - "RENAME TEMPORARY TABLE from '`tab1`' to '`tab2`': destination table already exists")) + "RENAME TEMPORARY VIEW from '`tab1`' to '`tab2`': destination table already exists")) val catalog = spark.sessionState.catalog assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"), TableIdentifier("tab2"))) } } - test("alter table: set location") { - testSetLocation(isDatasourceTable = false) - } - - test("alter table: set properties") { - testSetProperties(isDatasourceTable = false) - } - - test("alter table: unset properties") { - testUnsetProperties(isDatasourceTable = false) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde") { - testSetSerde(isDatasourceTable = false) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde partition") { - testSetSerdePartition(isDatasourceTable = false) - } - - test("alter table: change column") { - testChangeColumn(isDatasourceTable = false) - } - test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -885,10 +891,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") } - test("alter 
table: add partition") { - testAddPartitions(isDatasourceTable = false) - } - test("alter table: recover partitions (sequential)") { withSQLConf("spark.rdd.parallelListingThreshold" -> "10") { testRecoverPartitions() @@ -957,17 +959,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") } - test("alter table: drop partition") { - testDropPartitions(isDatasourceTable = false) - } - test("alter table: drop partition is not supported for views") { assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") } - test("alter table: rename partition") { - testRenamePartitions(isDatasourceTable = false) - } test("show databases") { sql("CREATE DATABASE showdb2B") @@ -1011,18 +1006,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(catalog.listTables("default") == Nil) } - test("drop table") { - testDropTable(isDatasourceTable = false) - } - protected def testDropTable(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) assert(catalog.listTables("dbx") == Seq(tableIdent)) sql("DROP TABLE dbx.tab1") assert(catalog.listTables("dbx") == Nil) @@ -1046,22 +1037,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { e.getMessage.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")) } - private def convertToDatasourceTable( - catalog: SessionCatalog, - tableIdent: TableIdentifier): Unit = { - catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( - provider = Some("csv"))) - assert(catalog.getTableMetadata(tableIdent).provider == Some("csv")) - } - protected def testSetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { if (isUsingHiveMetastore) { normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties @@ -1084,13 +1067,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { if (isUsingHiveMetastore) { normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties @@ -1121,15 +1104,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetLocation(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + 
assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val partSpec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, partSpec, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) assert(normalizeSerdeProp(catalog.getTableMetadata(tableIdent).storage.properties).isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) @@ -1171,13 +1154,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetSerde(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def checkSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { val serdeProp = catalog.getTableMetadata(tableIdent).storage.properties if (isUsingHiveMetastore) { @@ -1187,8 +1170,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } if (isUsingHiveMetastore) { - assert(catalog.getTableMetadata(tableIdent).storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(expectedSerde)) } else { assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) } @@ -1229,18 +1216,18 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val spec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, spec, tableIdent) createTablePartition(catalog, Map("a" -> "1", "b" -> "3"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "2"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "3"), tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } def checkPartitionSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { val serdeProp = catalog.getPartition(tableIdent, spec).storage.properties if (isUsingHiveMetastore) { @@ -1250,8 +1237,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } if (isUsingHiveMetastore) { - assert(catalog.getPartition(tableIdent, spec).storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + 
assert(catalog.getPartition(tableIdent, spec).storage.serde == Some(expectedSerde)) } else { assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) } @@ -1295,6 +1286,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testAddPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1303,11 +1297,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) // basic add partition @@ -1354,6 +1345,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testDropPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1362,7 +1356,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) @@ -1370,9 +1364,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { createTablePartition(catalog, part5, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3, part4, part5)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic drop partition sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") @@ -1407,20 +1398,20 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "q") val part2 = Map("a" -> "2", "b" -> "c") val part3 = Map("a" -> "3", "b" -> "p") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic rename partition sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") @@ -1451,14 +1442,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testChangeColumn(isDatasourceTable: Boolean): Unit = { + 
if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val resolver = spark.sessionState.conf.resolver val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getMetadata(colName: String): Metadata = { val column = catalog.getTableMetadata(tableIdent).schema.fields.find { field => resolver(field.name, colName) @@ -1601,13 +1592,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("drop current database") { - sql("CREATE DATABASE temp") - sql("USE temp") - sql("DROP DATABASE temp") - val e = intercept[AnalysisException] { + withDatabase("temp") { + sql("CREATE DATABASE temp") + sql("USE temp") + sql("DROP DATABASE temp") + val e = intercept[AnalysisException] { sql("CREATE TABLE t (a INT, b INT) USING parquet") }.getMessage - assert(e.contains("Database 'temp' not found")) + assert(e.contains("Database 'temp' not found")) + } } test("drop default database") { @@ -1837,22 +1830,25 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("tbl"), Row(1)) val defaultTablePath = spark.sessionState.catalog .getTableMetadata(TableIdentifier("tbl")).storage.locationUri.get - - sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") - spark.catalog.refreshTable("tbl") - // SET LOCATION won't move data from previous table path to new table path. - assert(spark.table("tbl").count() == 0) - // the previous table path should be still there. - assert(new File(defaultTablePath).exists()) - - sql("INSERT INTO tbl SELECT 2") - checkAnswer(spark.table("tbl"), Row(2)) - // newly inserted data will go to the new table path. - assert(dir.listFiles().nonEmpty) - - sql("DROP TABLE tbl") - // the new table path will be removed after DROP TABLE. - assert(!dir.exists()) + try { + sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") + spark.catalog.refreshTable("tbl") + // SET LOCATION won't move data from previous table path to new table path. + assert(spark.table("tbl").count() == 0) + // the previous table path should be still there. + assert(new File(defaultTablePath).exists()) + + sql("INSERT INTO tbl SELECT 2") + checkAnswer(spark.table("tbl"), Row(2)) + // newly inserted data will go to the new table path. + assert(dir.listFiles().nonEmpty) + + sql("DROP TABLE tbl") + // the new table path will be removed after DROP TABLE. 
+ assert(!dir.exists()) + } finally { + Utils.deleteRecursively(new File(defaultTablePath)) + } } } } @@ -1864,7 +1860,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t(a string, b int) |USING parquet - |OPTIONS(path "$dir") + |OPTIONS(path "${dir.toURI}") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -1882,12 +1878,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("t"), Row("c", 1) :: Nil) val newDirFile = new File(dir, "x") - val newDir = newDirFile.getAbsolutePath + val newDir = newDirFile.toURI spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") spark.sessionState.catalog.refreshTable(TableIdentifier("t")) val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table1.location == new URI(newDir)) + assert(table1.location == newDir) assert(!newDirFile.exists) spark.sql("INSERT INTO TABLE t SELECT 'c', 1") @@ -1905,7 +1901,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |CREATE TABLE t(a int, b int, c int, d int) |USING parquet |PARTITIONED BY(a, b) - |LOCATION "$dir" + |LOCATION "${dir.toURI}" """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -1931,7 +1927,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t(a string, b int) |USING parquet - |OPTIONS(path "$dir") + |OPTIONS(path "${dir.toURI}") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -1960,7 +1956,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |CREATE TABLE t(a int, b int, c int, d int) |USING parquet |PARTITIONED BY(a, b) - |LOCATION "$dir" + |LOCATION "${dir.toURI}" """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1977,7 +1973,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { test("create datasource table with a non-existing location") { withTable("t", "t1") { withTempPath { dir => - spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '$dir'") + spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '${dir.toURI}'") val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -1989,7 +1985,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } // partition table withTempPath { dir => - spark.sql(s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '$dir'") + spark.sql( + s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '${dir.toURI}'") val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -2014,7 +2011,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t |USING parquet - |LOCATION '$dir' + |LOCATION '${dir.toURI}' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -2030,7 +2027,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |CREATE TABLE t1 |USING parquet |PARTITIONED BY(a, b) - |LOCATION '$dir' + |LOCATION '${dir.toURI}' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d 
""".stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) @@ -2047,6 +2044,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars => test(s"data source table:partition column name containing $specialChars") { + // On Windows, it looks colon in the file name is illegal by default. See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + withTable("t") { withTempDir { dir => spark.sql( @@ -2054,14 +2055,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |CREATE TABLE t(a string, `$specialChars` string) |USING parquet |PARTITIONED BY(`$specialChars`) - |LOCATION '$dir' + |LOCATION '${dir.toURI}' """.stripMargin) assert(dir.listFiles().isEmpty) spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" val partFile = new File(dir, partEscaped) - assert(partFile.listFiles().length >= 1) + assert(partFile.listFiles().nonEmpty) checkAnswer(spark.table("t"), Row("1", "2") :: Nil) } } @@ -2070,15 +2071,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"location uri contains $specialChars for datasource table") { + // On Windows, it looks colon in the file name is illegal by default. See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + withTable("t", "t1") { withTempDir { dir => val loc = new File(dir, specialChars) loc.mkdir() + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t(a string) |USING parquet - |LOCATION '$loc' + |LOCATION '$escapedLoc' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -2087,19 +2095,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(loc.listFiles().isEmpty) spark.sql("INSERT INTO TABLE t SELECT 1") - assert(loc.listFiles().length >= 1) + assert(loc.listFiles().nonEmpty) checkAnswer(spark.table("t"), Row("1") :: Nil) } withTempDir { dir => val loc = new File(dir, specialChars) loc.mkdir() + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. 
+ val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t1(a string, b string) |USING parquet |PARTITIONED BY(b) - |LOCATION '$loc' + |LOCATION '$escapedLoc' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) @@ -2109,15 +2120,20 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(loc.listFiles().isEmpty) spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") val partFile = new File(loc, "b=2") - assert(partFile.listFiles().length >= 1) + assert(partFile.listFiles().nonEmpty) checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") assert(!partFile1.exists()) - val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") - assert(partFile2.listFiles().length >= 1) - checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + + if (!Utils.isWindows) { + // Actual path becomes "b=2017-03-03%2012%3A13%253A14" on Windows. + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().nonEmpty) + checkAnswer( + spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } } } } @@ -2125,11 +2141,18 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"location uri contains $specialChars for database") { - try { + // On Windows, it looks colon in the file name is illegal by default. See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + + withDatabase ("tmpdb") { withTable("t") { withTempDir { dir => val loc = new File(dir, specialChars) - spark.sql(s"CREATE DATABASE tmpdb LOCATION '$loc'") + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") + spark.sql(s"CREATE DATABASE tmpdb LOCATION '$escapedLoc'") spark.sql("USE tmpdb") import testImplicits._ @@ -2140,8 +2163,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(tblloc.listFiles().nonEmpty) } } - } finally { - spark.sql("DROP DATABASE IF EXISTS tmpdb") } } } @@ -2150,11 +2171,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { withTable("t", "t1") { withTempDir { dir => assert(!dir.getAbsolutePath.startsWith("file:/")) + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedDir = dir.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t(a string) |USING parquet - |LOCATION '$dir' + |LOCATION '$escapedDir' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location.toString.startsWith("file:/")) @@ -2162,12 +2186,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { withTempDir { dir => assert(!dir.getAbsolutePath.startsWith("file:/")) + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. 
+ val escapedDir = dir.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t1(a string, b string) |USING parquet |PARTITIONED BY(b) - |LOCATION '$dir' + |LOCATION '$escapedDir' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) assert(table.location.toString.startsWith("file:/")) @@ -2268,6 +2295,57 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create temporary function with if not exists") { + withUserDefinedFunction("func1" -> true) { + val sql1 = + """ + |CREATE TEMPORARY FUNCTION IF NOT EXISTS func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val e = intercept[AnalysisException] { + sql(sql1) + }.getMessage + assert(e.contains("It is not allowed to define a TEMPORARY function with IF NOT EXISTS")) + } + } + + test("create function with both if not exists and replace") { + withUserDefinedFunction("func1" -> false) { + val sql1 = + """ + |CREATE OR REPLACE FUNCTION IF NOT EXISTS func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val e = intercept[AnalysisException] { + sql(sql1) + }.getMessage + assert(e.contains("CREATE FUNCTION with both IF NOT EXISTS and REPLACE is not allowed")) + } + } + + test("create temporary function by specifying a database") { + val dbName = "mydb" + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + sql(s"USE $dbName") + withUserDefinedFunction("func1" -> true) { + val sql1 = + s""" + |CREATE TEMPORARY FUNCTION $dbName.func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val e = intercept[AnalysisException] { + sql(sql1) + }.getMessage + assert(e.contains(s"Specifying a database in CREATE TEMPORARY FUNCTION " + + s"is not allowed: '$dbName'")) + } + } + } + Seq(true, false).foreach { caseSensitive => test(s"alter table add columns with existing column name - caseSensitive $caseSensitive") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { @@ -2279,18 +2357,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { }.getMessage assert(e.contains("Found duplicate column(s)")) } else { - if (isUsingHiveMetastore) { - // hive catalog will still complains that c1 is duplicate column name because hive - // identifiers are case insensitive. - val e = intercept[AnalysisException] { - sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") - }.getMessage - assert(e.contains("HiveException")) - } else { - sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") - assert(spark.table("t1").schema - .equals(new StructType().add("c1", IntegerType).add("C1", StringType))) - } + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + assert(spark.table("t1").schema == + new StructType().add("c1", IntegerType).add("C1", StringType)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala new file mode 100644 index 000000000000..a0c1ea63d382 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext + +class FileFormatWriterSuite extends QueryTest with SharedSQLContext { + + test("empty file should be skipped while write to file") { + withTempPath { path => + spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString) + val partFiles = path.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 2) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 8703fe96e587..c1d61b843d89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -190,7 +190,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi checkDataFilters(Set.empty) // Only one file should be read. - checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions => + checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 2")) { partitions => assert(partitions.size == 1, "when checking partitions") assert(partitions.head.files.size == 1, "when checking files in partition 1") assert(partitions.head.files.head.partitionValues.getInt(0) == 1, @@ -217,7 +217,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi checkDataFilters(Set.empty) // Only one file should be read. - checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 1")) { partitions => + checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 2")) { partitions => assert(partitions.size == 1, "when checking partitions") assert(partitions.head.files.size == 1, "when checking files in partition 1") assert(partitions.head.files.head.partitionValues.getInt(0) == 1, @@ -235,13 +235,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi "p1=1/file1" -> 10, "p1=2/file2" -> 10)) - val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1") + val df1 = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1") // Filter on data only are advisory so we have to reevaluate. - assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1")) - // Need to evalaute filters that are not pushed down. - assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2")) + assert(getPhysicalFilters(df1) contains resolve(df1, "c1 = 1")) // Don't reevaluate partition only filters. - assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1"))) + assert(!(getPhysicalFilters(df1) contains resolve(df1, "p1 = 1"))) + + val df2 = table.where("(p1 + c2) = 2 AND c1 = 1") + // Filter on data only are advisory so we have to reevaluate. 
+ assert(getPhysicalFilters(df2) contains resolve(df2, "c1 = 1")) + // Need to evaluate filters that are not pushed down. + assert(getPhysicalFilters(df2) contains resolve(df2, "(p1 + c2) = 2")) } test("bucketed table") { @@ -395,7 +399,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val fileCatalog = new InMemoryFileIndex( sparkSession = spark, - rootPaths = Seq(new Path(tempDir)), + rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], partitionSchema = None) // This should not fail. @@ -552,7 +556,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi if (buckets > 0) { val bucketed = df.queryExecution.analyzed transform { - case l @ LogicalRelation(r: HadoopFsRelation, _, _) => + case l @ LogicalRelation(r: HadoopFsRelation, _, _, _) => l.copy(relation = r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))(r.sparkSession)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index becb3aa27040..caf03885e387 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -36,7 +36,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { }) val totalSize = allFiles.map(_.length()).sum val df = spark.read.parquet(dir.toString) - assert(df.queryExecution.logical.stats(sqlConf).sizeInBytes === BigInt(totalSize)) + assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 352dba79a4c0..e439699605ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -24,8 +24,8 @@ import java.text.SimpleDateFormat import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat -import org.apache.hadoop.io.compress.GzipCodec import org.apache.hadoop.io.SequenceFile.CompressionType +import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} @@ -261,10 +261,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val cars = spark.read .format("csv") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .options(Map("header" -> "true", "mode" -> "dropmalformed")) .load(testFile(carsFile)) @@ -284,11 +284,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for FAILFAST parsing mode") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val exception = intercept[SparkException] { spark.read .format("csv") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .options(Map("header" -> "true", "mode" -> "failfast")) .load(testFile(carsFile)).collect() } @@ -990,13 +990,13 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } 
test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val schema = new StructType().add("a", IntegerType).add("b", TimestampType) // We use `PERMISSIVE` mode by default if invalid string is given. val df1 = spark .read .option("mode", "abcd") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schema) .csv(testFile(valueMalformedFile)) checkAnswer(df1, @@ -1011,7 +1011,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schemaWithCorrField1) .csv(testFile(valueMalformedFile)) checkAnswer(df2, @@ -1028,7 +1028,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schemaWithCorrField2) .csv(testFile(valueMalformedFile)) checkAnswer(df3, @@ -1041,7 +1041,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) .csv(testFile(valueMalformedFile)) .collect @@ -1073,7 +1073,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val df = spark.read .option("header", true) - .option("wholeFile", true) + .option("multiLine", true) .csv(path.getAbsolutePath) // Check if headers have new lines in the names. 
@@ -1096,10 +1096,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Empty file produces empty dataframe with empty schema") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val df = spark.read.format("csv") .option("header", true) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .load(testFile(emptyFile)) assert(df.schema === spark.emptyDataFrame.schema) @@ -1174,4 +1174,75 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } } + + test("SPARK-21263: Invalid float and double are handled correctly in different modes") { + val exception = intercept[SparkException] { + spark.read.schema("a DOUBLE") + .option("mode", "FAILFAST") + .csv(Seq("10u12").toDS()) + .collect() + } + assert(exception.getMessage.contains("""input string: "10u12"""")) + + val count = spark.read.schema("a FLOAT") + .option("mode", "DROPMALFORMED") + .csv(Seq("10u12").toDS()) + .count() + assert(count == 0) + + val results = spark.read.schema("a FLOAT") + .option("mode", "PERMISSIVE") + .csv(Seq("10u12").toDS()) + checkAnswer(results, Row(null)) + } + + test("SPARK-20978: Fill the malformed column when the number of tokens is less than schema") { + val df = spark.read + .schema("a string, b string, unparsed string") + .option("columnNameOfCorruptRecord", "unparsed") + .csv(Seq("a").toDS()) + checkAnswer(df, Row("a", null, "a")) + } + + test("SPARK-21610: Corrupt records are not handled properly when creating a dataframe " + + "from a file") { + val columnNameOfCorruptRecord = "_corrupt_record" + val schema = new StructType() + .add("a", IntegerType) + .add("b", TimestampType) + .add(columnNameOfCorruptRecord, StringType) + // negative cases + val msg = intercept[AnalysisException] { + spark + .read + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schema) + .csv(testFile(valueMalformedFile)) + .select(columnNameOfCorruptRecord) + .collect() + }.getMessage + assert(msg.contains("only include the internal corrupt record column")) + intercept[org.apache.spark.sql.catalyst.errors.TreeNodeException[_]] { + spark + .read + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schema) + .csv(testFile(valueMalformedFile)) + .filter($"_corrupt_record".isNotNull) + .count() + } + // workaround + val df = spark + .read + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schema) + .csv(testFile(valueMalformedFile)) + .cache() + assert(df.filter($"_corrupt_record".isNotNull).count() == 1) + assert(df.filter($"_corrupt_record".isNull).count() == 1) + checkAnswer( + df.select(columnNameOfCorruptRecord), + Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index a74b22a4a88a..efbf73534bd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -130,16 +130,17 @@ class UnivocityParserSuite extends SparkFunSuite { DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) } - test("Float and Double Types are cast without respect to platform default Locale") { - val originalLocale = Locale.getDefault - try { - Locale.setDefault(new Locale("fr", "FR")) - // 
Would parse as 1.0 in fr-FR - val options = new CSVOptions(Map.empty[String, String], "GMT") - assert(parser.makeConverter("_1", FloatType, options = options).apply("1,00") == 100.0) - assert(parser.makeConverter("_1", DoubleType, options = options).apply("1,00") == 100.0) - } finally { - Locale.setDefault(originalLocale) + test("Throws exception for casting an invalid string to Float and Double Types") { + val options = new CSVOptions(Map.empty[String, String], "GMT") + val types = Seq(DoubleType, FloatType) + val input = Seq("10u000", "abc", "1 2/3") + types.foreach { dt => + input.foreach { v => + val message = intercept[NumberFormatException] { + parser.makeConverter("_1", dt, options = options).apply(v) + }.getMessage + assert(message.contains(v)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala new file mode 100644 index 000000000000..7d277c1ffaff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.types._ + +class JdbcUtilsSuite extends SparkFunSuite { + + val tableSchema = StructType(Seq( + StructField("C1", StringType, false), StructField("C2", IntegerType, false))) + val caseSensitive = org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + val caseInsensitive = org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + + test("Parse user specified column types") { + assert(JdbcUtils.getCustomSchema(tableSchema, null, caseInsensitive) === tableSchema) + assert(JdbcUtils.getCustomSchema(tableSchema, "", caseInsensitive) === tableSchema) + + assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE", caseInsensitive) === + StructType(Seq(StructField("C1", DateType, false), StructField("C2", IntegerType, false)))) + assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE", caseSensitive) === + StructType(Seq(StructField("C1", StringType, false), StructField("C2", IntegerType, false)))) + + assert( + JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseInsensitive) === + StructType(Seq(StructField("C1", DateType, false), StructField("C2", StringType, false)))) + assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseSensitive) === + StructType(Seq(StructField("C1", StringType, false), StructField("C2", StringType, false)))) + + // Throw AnalysisException + val duplicate = intercept[AnalysisException]{ + JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, c1 STRING", caseInsensitive) === + StructType(Seq(StructField("c1", DateType, false), StructField("c1", StringType, false))) + } + assert(duplicate.getMessage.contains( + "Found duplicate column(s) in the customSchema option value")) + + // Throw ParseException + val dataTypeNotSupported = intercept[ParseException]{ + JdbcUtils.getCustomSchema(tableSchema, "c3 DATEE, C2 STRING", caseInsensitive) === + StructType(Seq(StructField("c3", DateType, false), StructField("C2", StringType, false))) + } + assert(dataTypeNotSupported.getMessage.contains("DataType datee is not supported")) + + val mismatchedInput = intercept[ParseException]{ + JdbcUtils.getCustomSchema(tableSchema, "c3 DATE. C2 STRING", caseInsensitive) === + StructType(Seq(StructField("c3", DateType, false), StructField("C2", StringType, false))) + } + assert(mismatchedInput.getMessage.contains("mismatched input '.' 
expecting")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 6e2b4f0df595..316c5183fddf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -72,6 +72,21 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { assert(df.first().getString(0) == "Reynold Xin") } + test("allowUnquotedControlChars off") { + val str = """{"name": "a\u0001b"}""" + val df = spark.read.json(Seq(str).toDS()) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowUnquotedControlChars on") { + val str = """{"name": "a\u0001b"}""" + val df = spark.read.option("allowUnquotedControlChars", "true").json(Seq(str).toDS()) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "a\u0001b") + } + test("allowNumericLeadingZeros off") { val str = """{"age": 0018}""" val df = spark.read.json(Seq(str).toDS()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 2ab03819964b..8c8d41ebf115 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -20,17 +20,19 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import java.util.Locale import com.fasterxml.jackson.core.JsonFactory import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec -import org.apache.spark.rdd.RDD import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType import org.apache.spark.sql.internal.SQLConf @@ -824,7 +826,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("b", DecimalType(2, 2), true):: Nil) assert(expectedSchema === jsonDF.schema) - checkAnswer(jsonDF, Row(1.0E-39D, BigDecimal(0.01))) + checkAnswer(jsonDF, Row(1.0E-39D, BigDecimal("0.01"))) val mergedJsonDF = spark.read .option("prefersDecimal", "true") @@ -837,7 +839,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedMergedSchema === mergedJsonDF.schema) checkAnswer( mergedJsonDF, - Row(1.0E-39D, BigDecimal(0.01)) :: + Row(1.0E-39D, BigDecimal("0.01")) :: Row(1.0E38D, BigDecimal("92233720368547758070")) :: Nil ) } @@ -935,14 +937,16 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Map("e" -> null)) :: Nil ) - checkAnswer( - sql("select `map`['c'] from jsonWithSimpleMap"), - Row(null) :: - Row(null) :: - Row(3) :: - Row(1) :: - Row(null) :: Nil - ) + 
withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + checkAnswer( + sql("select `map`['c'] from jsonWithSimpleMap"), + Row(null) :: + Row(null) :: + Row(3) :: + Row(1) :: + Row(null) :: Nil + ) + } val innerStruct = StructType( StructField("field1", ArrayType(IntegerType, true), true) :: @@ -964,15 +968,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Map("f" -> Row(null, null))) :: Nil ) - checkAnswer( - sql("select `map`['a'].field1, `map`['c'].field2 from jsonWithComplexMap"), - Row(Seq(1, 2, 3, null), null) :: - Row(null, null) :: - Row(null, 4) :: - Row(null, 3) :: - Row(null, null) :: - Row(null, null) :: Nil - ) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + checkAnswer( + sql("select `map`['a'].field1, `map`['c'].field2 from jsonWithComplexMap"), + Row(Seq(1, 2, 3, null), null) :: + Row(null, null) :: + Row(null, 4) :: + Row(null, 3) :: + Row(null, null) :: + Row(null, null) :: Nil + ) + } } test("SPARK-2096 Correctly parse dot notations") { @@ -1034,24 +1040,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Corrupt records: FAILFAST mode") { - val schema = StructType( - StructField("a", StringType, true) :: Nil) // `FAILFAST` mode should throw an exception for corrupt records. val exceptionOne = intercept[SparkException] { spark.read .option("mode", "FAILFAST") .json(corruptRecords) - } - assert(exceptionOne.getMessage.contains("JsonParseException")) + }.getMessage + assert(exceptionOne.contains( + "Malformed records are detected in schema inference. Parse Mode: FAILFAST.")) val exceptionTwo = intercept[SparkException] { spark.read .option("mode", "FAILFAST") - .schema(schema) + .schema("a string") .json(corruptRecords) .collect() - } - assert(exceptionTwo.getMessage.contains("JsonParseException")) + }.getMessage + assert(exceptionTwo.contains( + "Malformed records are detected in record parsing. 
Parse Mode: FAILFAST.")) } test("Corrupt records: DROPMALFORMED mode") { @@ -1326,6 +1332,15 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } + test("Dataset toJSON doesn't construct rdd") { + val containsRDD = spark.emptyDataFrame.toJSON.queryExecution.logical.find { + case ExternalRDD(_, _) => true + case _ => false + } + + assert(containsRDD.isEmpty, "Expected logical plan of toJSON to not contain an RDD") + } + test("JSONRelation equality test") { withTempPath(dir => { val path = dir.getCanonicalFile.toURI.toString @@ -1803,7 +1818,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(new File(path).listFiles().exists(_.getName.endsWith(".gz"))) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .option("compression", "gZiP") @@ -1825,7 +1840,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write.json(jsonDir) @@ -1854,7 +1869,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) // no corrupt record column should be created assert(jsonDF.schema === StructType(Seq())) // only the first object should be read @@ -1875,7 +1890,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path) + val jsonDF = spark.read.option("multiLine", true).option("mode", "PERMISSIVE").json(path) assert(jsonDF.count() === corruptRecordCount) assert(jsonDF.schema === new StructType() .add("_corrupt_record", StringType) @@ -1906,7 +1921,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path) + val jsonDF = spark.read.option("multiLine", true).option("mode", "DROPMALFORMED").json(path) checkAnswer(jsonDF, Seq(Row("test"))) } } @@ -1929,21 +1944,23 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // `FAILFAST` mode should throw an exception for corrupt records. val exceptionOne = intercept[SparkException] { spark.read - .option("wholeFile", true) + .option("multiLine", true) .option("mode", "FAILFAST") .json(path) } - assert(exceptionOne.getMessage.contains("Failed to infer a common schema")) + assert(exceptionOne.getMessage.contains("Malformed records are detected in schema " + + "inference. Parse Mode: FAILFAST.")) val exceptionTwo = intercept[SparkException] { spark.read - .option("wholeFile", true) + .option("multiLine", true) .option("mode", "FAILFAST") .schema(schema) .json(path) .collect() } - assert(exceptionTwo.getMessage.contains("Failed to parse a value")) + assert(exceptionTwo.getMessage.contains("Malformed records are detected in record " + + "parsing. 
Parse Mode: FAILFAST.")) } } @@ -1978,4 +1995,72 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) } } + + test("SPARK-18772: Parse special floats correctly") { + val jsons = Seq( + """{"a": "NaN"}""", + """{"a": "Infinity"}""", + """{"a": "-Infinity"}""") + + // positive cases + val checks: Seq[Double => Boolean] = Seq( + _.isNaN, + _.isPosInfinity, + _.isNegInfinity) + + Seq(FloatType, DoubleType).foreach { dt => + jsons.zip(checks).foreach { case (json, check) => + val ds = spark.read + .schema(StructType(Seq(StructField("a", dt)))) + .json(Seq(json).toDS()) + .select($"a".cast(DoubleType)).as[Double] + assert(check(ds.first())) + } + } + + // negative cases + Seq(FloatType, DoubleType).foreach { dt => + val lowerCasedJsons = jsons.map(_.toLowerCase(Locale.ROOT)) + // The special floats are case-sensitive so these cases below throw exceptions. + lowerCasedJsons.foreach { lowerCasedJson => + val e = intercept[SparkException] { + spark.read + .option("mode", "FAILFAST") + .schema(StructType(Seq(StructField("a", dt)))) + .json(Seq(lowerCasedJson).toDS()) + .collect() + } + assert(e.getMessage.contains("Cannot parse")) + } + } + } + + test("SPARK-21610: Corrupt records are not handled properly when creating a dataframe " + + "from a file") { + withTempPath { dir => + val path = dir.getCanonicalPath + val data = + """{"field": 1} + |{"field": 2} + |{"field": "3"}""".stripMargin + Seq(data).toDF().repartition(1).write.text(path) + val schema = new StructType().add("field", ByteType).add("_corrupt_record", StringType) + // negative cases + val msg = intercept[AnalysisException] { + spark.read.schema(schema).json(path).select("_corrupt_record").collect() + }.getMessage + assert(msg.contains("only include the internal corrupt record column")) + intercept[catalyst.errors.TreeNodeException[_]] { + spark.read.schema(schema).json(path).filter($"_corrupt_record".isNotNull).count() + } + // workaround + val df = spark.read.schema(schema).json(path).cache() + assert(df.filter($"_corrupt_record".isNotNull).count() == 1) + assert(df.filter($"_corrupt_record".isNull).count() == 2) + checkAnswer( + df.select("_corrupt_record"), + Row(null) :: Row(null) :: Row("{\"field\": \"3\"}") :: Nil + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index dd53b561326f..90f6620d990c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -63,7 +63,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex var maybeRelation: Option[HadoopFsRelation] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(relation: HadoopFsRelation, _, _)) => + case PhysicalOperation(_, filters, + LogicalRelation(relation: HadoopFsRelation, _, _, _)) => maybeRelation = Some(relation) filters }.flatten.reduceLeftOption(_ && _) @@ -505,7 +506,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex sparkContext.register(accu) val df = spark.read.parquet(path).filter("a < 100") - df.foreachPartition(_.foreach(v => accu.add(0))) + 
df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) df.collect if (enablePushDown) { @@ -538,6 +539,22 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // scalastyle:on nonascii } } + + test("SPARK-20364: Disable Parquet predicate pushdown for fields having dots in the names") { + import testImplicits._ + + Seq(true, false).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString, + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> true.toString, + SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + withTempPath { path => + Seq(Some(1), None).toDF("col.dots").write.parquet(path.getAbsolutePath) + val readBack = spark.read.parquet(path.getAbsolutePath).where("`col.dots` IS NOT NULL") + assert(readBack.count() == 1) + } + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 94a2f9a00b3f..d76990b482db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -211,7 +211,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } testStandardAndLegacyModes("nested map with struct as value type") { - val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) + val data = (1 to 4).map(i => Tuple1(Map(i -> ((i, s"val_$i"))))) withParquetDataFrame(data) { df => checkAnswer(df, data.map { case Tuple1(m) => Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index b4f3de996120..f79b92b804c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition} +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -650,7 +651,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { case LogicalRelation( - HadoopFsRelation(location: PartitioningAwareFileIndex, _, _, _, _, _), _, _) => + HadoopFsRelation(location: PartitioningAwareFileIndex, _, _, _, _, _), _, _, _) => assert(location.partitionSpec() === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a matching HadoopFsRelation, but got:\n$queryExecution") @@ -676,7 +677,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha 1.5.toFloat, 4.5, new java.math.BigDecimal(new BigInteger("212500"), 5), - new 
java.math.BigDecimal(2.125), + new java.math.BigDecimal("2.125"), java.sql.Date.valueOf("2015-05-23"), new Timestamp(0), "This is a string, /[]?=:", @@ -1022,4 +1023,48 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } } } + + test("SPARK-21463: MetadataLogFileIndex should respect userSpecifiedSchema for partition cols") { + withTempDir { tempDir => + val output = new File(tempDir, "output").toString + val checkpoint = new File(tempDir, "chkpoint").toString + try { + val stream = MemoryStream[(String, Int)] + val df = stream.toDS().toDF("time", "value") + val sq = df.writeStream + .option("checkpointLocation", checkpoint) + .format("parquet") + .partitionBy("time") + .start(output) + + stream.addData(("2017-01-01-00", 1), ("2017-01-01-01", 2)) + sq.processAllAvailable() + + val schema = new StructType() + .add("time", StringType) + .add("value", IntegerType) + val readBack = spark.read.schema(schema).parquet(output) + assert(readBack.schema.toSet === schema.toSet) + + checkAnswer( + readBack, + Seq(Row("2017-01-01-00", 1), Row("2017-01-01-01", 2)) + ) + } finally { + spark.streams.active.foreach(_.stop()) + } + } + } + + test("SPARK-22109: Resolve type conflicts between strings and timestamps in partition column") { + val df = Seq( + (1, "2015-01-01 00:00:00"), + (2, "2014-01-01 00:00:00"), + (3, "blah")).toDF("i", "str") + + withTempPath { path => + df.write.format("parquet").partitionBy("str").save(path.getAbsolutePath) + checkAnswer(spark.read.load(path.getAbsolutePath), df) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 487d7a7e5ac8..0917f188b979 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -22,8 +22,8 @@ import scala.collection.JavaConverters._ import scala.util.Try import org.apache.spark.SparkConf -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{Benchmark, Utils} /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 4fc52c99fbee..adcaf2d76519 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -38,4 +38,11 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) } + + test("debugCodegenStringSeq") { + val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + assert(res.length == 2) + assert(res.forall{ case (subtree, code) => + subtree.contains("Range") && code.contains("Object[]")}) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 26c45e092dc6..a0fad862b44c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,8 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} -import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -157,7 +157,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } test("broadcast hint in SQL") { - import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join} + import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, Join} spark.range(10).createOrReplaceTempView("t") spark.range(10).createOrReplaceTempView("u") @@ -170,12 +170,12 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution .optimizedPlan - assert(plan1.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) - assert(!plan1.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) - assert(!plan2.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) - assert(plan2.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) - assert(!plan3.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) - assert(!plan3.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + assert(plan1.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(!plan1.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + assert(!plan2.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(plan2.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + assert(!plan3.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(!plan3.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 2ce7db6a22c0..0dc612ef735f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,76 +19,32 @@ package org.apache.spark.sql.execution.metric import java.io.File -import scala.collection.mutable.HashMap +import scala.util.Random import org.apache.spark.SparkFunSuite -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{AccumulatorContext, JsonProtocol} -class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { +class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with SharedSQLContext { import testImplicits._ + /** - * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". - * - * @param df `DataFrame` to run - * @param expectedNumOfJobs number of jobs that will run - * @param expectedMetrics the expected metrics. 
The format is - * `nodeId -> (operatorName, metric name -> metric value)`. + * Generates a `DataFrame` by filling randomly generated bytes for hash collision. */ - private def testSparkPlanMetrics( - df: DataFrame, - expectedNumOfJobs: Int, - expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - df.collect() - } - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = - spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= expectedNumOfJobs) - if (jobs.size == expectedNumOfJobs) { - // If we can track all jobs, check the metric values - val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( - df.queryExecution.executedPlan)).allNodes.filter { node => - expectedMetrics.contains(node.id) - }.map { node => - val nodeMetrics = node.metrics.map { metric => - val metricValue = metricValues(metric.accumulatorId) - (metric.name, metricValue) - }.toMap - (node.id, node.name -> nodeMetrics) - }.toMap - - assert(expectedMetrics.keySet === actualMetrics.keySet) - for (nodeId <- expectedMetrics.keySet) { - val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId) - val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) - assert(expectedNodeName === actualNodeName) - for (metricName <- expectedMetricsMap.keySet) { - assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) - } - } - } else { - // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. - // Since we cannot track all jobs, the metric values could be wrong and we should not check - // them. - logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") + private def generateRandomBytesDF(numRows: Int = 65535): DataFrame = { + val random = new Random() + val manyBytes = (0 until numRows).map { _ => + val byteArrSize = random.nextInt(100) + val bytes = new Array[Byte](byteArrSize) + random.nextBytes(bytes) + (bytes, random.nextInt(100)) } + manyBytes.toSeq.toDF("a", "b") } test("LocalTableScanExec computes metrics in collect and take") { @@ -112,8 +68,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) val df = person.filter('age < 25) testSparkPlanMetrics(df, 1, Map( - 0L -> ("Filter", Map( - "number of output rows" -> 1L))) + 0L -> (("Filter", Map( + "number of output rows" -> 1L)))) ) } @@ -130,16 +86,85 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // ... 
-> HashAggregate(nodeId = 2) -> Exchange(nodeId = 1) // -> HashAggregate(nodeId = 0) val df = testData2.groupBy().count() // 2 partitions + val expected1 = Seq( + Map("number of output rows" -> 2L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), + Map("number of output rows" -> 1L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) testSparkPlanMetrics(df, 1, Map( - 2L -> ("HashAggregate", Map("number of output rows" -> 2L)), - 0L -> ("HashAggregate", Map("number of output rows" -> 1L))) + 2L -> (("HashAggregate", expected1(0))), + 0L -> (("HashAggregate", expected1(1)))) ) // 2 partitions and each partition contains 2 keys val df2 = testData2.groupBy('a).count() + val expected2 = Seq( + Map("number of output rows" -> 4L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), + Map("number of output rows" -> 3L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) + testSparkPlanMetrics(df2, 1, Map( + 2L -> (("HashAggregate", expected2(0))), + 0L -> (("HashAggregate", expected2(1)))) + ) + } + + test("Aggregate metrics: track avg probe") { + // The executed plan looks like: + // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L]) + // +- Exchange hashpartitioning(a#61, 5) + // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L]) + // +- Exchange RoundRobinPartitioning(1) + // +- LocalTableScan [a#61] + // + // Assume the execution plan with node id is: + // Wholestage disabled: + // HashAggregate(nodeId = 0) + // Exchange(nodeId = 1) + // HashAggregate(nodeId = 2) + // Exchange (nodeId = 3) + // LocalTableScan(nodeId = 4) + // + // Wholestage enabled: + // WholeStageCodegen(nodeId = 0) + // HashAggregate(nodeId = 1) + // Exchange(nodeId = 2) + // WholeStageCodegen(nodeId = 3) + // HashAggregate(nodeId = 4) + // Exchange(nodeId = 5) + // LocalTableScan(nodeId = 6) + Seq(true, false).foreach { enableWholeStage => + val df = generateRandomBytesDF().repartition(1).groupBy('a).count() + val nodeIds = if (enableWholeStage) { + Set(4L, 1L) + } else { + Set(2L, 0L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } + } + } + } + + test("ObjectHashAggregate metrics") { + // Assume the execution plan is + // ... 
-> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> ObjectHashAggregate(nodeId = 0) + val df = testData2.groupBy().agg(collect_set('a)) // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> (("ObjectHashAggregate", Map("number of output rows" -> 2L))), + 0L -> (("ObjectHashAggregate", Map("number of output rows" -> 1L)))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).agg(collect_set('a)) testSparkPlanMetrics(df2, 1, Map( - 2L -> ("HashAggregate", Map("number of output rows" -> 4L)), - 0L -> ("HashAggregate", Map("number of output rows" -> 3L))) + 2L -> (("ObjectHashAggregate", Map("number of output rows" -> 4L))), + 0L -> (("ObjectHashAggregate", Map("number of output rows" -> 3L)))) ) } @@ -161,9 +186,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = spark.sql( "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( - 0L -> ("SortMergeJoin", Map( + 0L -> (("SortMergeJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of output rows" -> 4L))) + "number of output rows" -> 4L)))) ) } } @@ -179,17 +204,17 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = spark.sql( "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( - 0L -> ("SortMergeJoin", Map( + 0L -> (("SortMergeJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of output rows" -> 8L))) + "number of output rows" -> 8L)))) ) val df2 = spark.sql( "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df2, 1, Map( - 0L -> ("SortMergeJoin", Map( + 0L -> (("SortMergeJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of output rows" -> 8L))) + "number of output rows" -> 8L)))) ) } } @@ -201,11 +226,121 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // ... 
-> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) val df = df1.join(broadcast(df2), "key") testSparkPlanMetrics(df, 2, Map( - 1L -> ("BroadcastHashJoin", Map( - "number of output rows" -> 2L))) + 1L -> (("BroadcastHashJoin", Map( + "number of output rows" -> 2L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")))) ) } + test("BroadcastHashJoin metrics: track avg probe") { + // The executed plan looks like: + // Project [a#210, b#211, b#221] + // +- BroadcastHashJoin [a#210], [a#220], Inner, BuildRight + // :- Project [_1#207 AS a#210, _2#208 AS b#211] + // : +- Filter isnotnull(_1#207) + // : +- LocalTableScan [_1#207, _2#208] + // +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, binary, true])) + // +- Project [_1#217 AS a#220, _2#218 AS b#221] + // +- Filter isnotnull(_1#217) + // +- LocalTableScan [_1#217, _2#218] + // + // Assume the execution plan with node id is + // WholeStageCodegen disabled: + // Project(nodeId = 0) + // BroadcastHashJoin(nodeId = 1) + // ...(ignored) + // + // WholeStageCodegen enabled: + // WholeStageCodegen(nodeId = 0) + // Project(nodeId = 1) + // BroadcastHashJoin(nodeId = 2) + // Project(nodeId = 3) + // Filter(nodeId = 4) + // ...(ignored) + Seq(true, false).foreach { enableWholeStage => + val df1 = generateRandomBytesDF() + val df2 = generateRandomBytesDF() + val df = df1.join(broadcast(df2), "a") + val nodeIds = if (enableWholeStage) { + Set(2L) + } else { + Set(1L) + } + val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } + } + } + } + + test("ShuffledHashJoin metrics") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "40", + "spark.sql.shuffle.partitions" -> "2", + "spark.sql.join.preferSortMergeJoin" -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key", "value") + // Assume the execution plan is + // ... 
-> ShuffledHashJoin(nodeId = 1) -> Project(nodeId = 0) + val df = df1.join(df2, "key") + val metrics = getSparkPlanMetrics(df, 1, Set(1L)) + testSparkPlanMetrics(df, 1, Map( + 1L -> (("ShuffledHashJoin", Map( + "number of output rows" -> 2L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")))) + ) + } + } + + test("ShuffledHashJoin metrics: track avg probe") { + // The executed plan looks like: + // Project [a#308, b#309, b#319] + // +- ShuffledHashJoin [a#308], [a#318], Inner, BuildRight + // :- Exchange hashpartitioning(a#308, 2) + // : +- Project [_1#305 AS a#308, _2#306 AS b#309] + // : +- Filter isnotnull(_1#305) + // : +- LocalTableScan [_1#305, _2#306] + // +- Exchange hashpartitioning(a#318, 2) + // +- Project [_1#315 AS a#318, _2#316 AS b#319] + // +- Filter isnotnull(_1#315) + // +- LocalTableScan [_1#315, _2#316] + // + // Assume the execution plan with node id is + // WholeStageCodegen disabled: + // Project(nodeId = 0) + // ShuffledHashJoin(nodeId = 1) + // ...(ignored) + // + // WholeStageCodegen enabled: + // WholeStageCodegen(nodeId = 0) + // Project(nodeId = 1) + // ShuffledHashJoin(nodeId = 2) + // ...(ignored) + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "5000000", + "spark.sql.shuffle.partitions" -> "2", + "spark.sql.join.preferSortMergeJoin" -> "false") { + Seq(true, false).foreach { enableWholeStage => + val df1 = generateRandomBytesDF(65535 * 5) + val df2 = generateRandomBytesDF(65535) + val df = df1.join(df2, "a") + val nodeIds = if (enableWholeStage) { + Set(2L) + } else { + Set(1L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } + } + } + } + } + test("BroadcastHashJoin(outer) metrics") { val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") @@ -213,14 +348,14 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // ... -> BroadcastHashJoin(nodeId = 0) val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") testSparkPlanMetrics(df, 2, Map( - 0L -> ("BroadcastHashJoin", Map( - "number of output rows" -> 5L))) + 0L -> (("BroadcastHashJoin", Map( + "number of output rows" -> 5L)))) ) val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") testSparkPlanMetrics(df3, 2, Map( - 0L -> ("BroadcastHashJoin", Map( - "number of output rows" -> 6L))) + 0L -> (("BroadcastHashJoin", Map( + "number of output rows" -> 6L)))) ) } @@ -235,8 +370,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { "SELECT * FROM testData2 left JOIN testDataForJoin ON " + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") testSparkPlanMetrics(df, 3, Map( - 1L -> ("BroadcastNestedLoopJoin", Map( - "number of output rows" -> 12L))) + 1L -> (("BroadcastNestedLoopJoin", Map( + "number of output rows" -> 12L)))) ) } } @@ -249,8 +384,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // ... 
-> BroadcastHashJoin(nodeId = 0) val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") testSparkPlanMetrics(df, 2, Map( - 0L -> ("BroadcastHashJoin", Map( - "number of output rows" -> 2L))) + 0L -> (("BroadcastHashJoin", Map( + "number of output rows" -> 2L)))) ) } @@ -264,18 +399,33 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = spark.sql( "SELECT * FROM testData2 JOIN testDataForJoin") testSparkPlanMetrics(df, 1, Map( - 0L -> ("CartesianProduct", Map("number of output rows" -> 12L))) + 0L -> (("CartesianProduct", Map("number of output rows" -> 12L)))) ) } } } + test("SortMergeJoin(left-anti) metrics") { + val anti = testData2.filter("a > 2") + withTempView("antiData") { + anti.createOrReplaceTempView("antiData") + val df = spark.sql( + "SELECT * FROM testData2 ANTI JOIN antiData ON testData2.a = antiData.a") + testSparkPlanMetrics(df, 1, Map( + 0L -> (("SortMergeJoin", Map("number of output rows" -> 4L)))) + ) + } + } + test("save metrics") { withTempPath { file => + // person creates a temporary view. get the DF before listing previous execution IDs + val data = person.select('name) + sparkContext.listenerBus.waitUntilEmpty(10000) val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) + data.write.format("json").save(file.getAbsolutePath) sparkContext.listenerBus.waitUntilEmpty(10000) val executionIds = spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) @@ -342,75 +492,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil) } } -} - -object InputOutputMetricsHelper { - private class InputOutputMetricsListener extends SparkListener { - private case class MetricsResult( - var recordsRead: Long = 0L, - var shuffleRecordsRead: Long = 0L, - var sumMaxOutputRows: Long = 0L) - - private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult] - - def reset(): Unit = { - stageIdToMetricsResult.clear() - } - - /** - * Return a list of recorded metrics aggregated per stage. - * - * The list is sorted in the ascending order on the stageId. - * For each recorded stage, the following tuple is returned: - * - sum of inputMetrics.recordsRead for all the tasks in the stage - * - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage - * - sum of the highest values of "number of output rows" metric for all the tasks in the stage - */ - def getResults(): List[(Long, Long, Long)] = { - stageIdToMetricsResult.keySet.toList.sorted.map { stageId => - val res = stageIdToMetricsResult(stageId) - (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows) - } - } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult()) - - res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead - res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead - - var maxOutputRows = 0L - for (accum <- taskEnd.taskMetrics.externalAccums) { - val info = accum.toInfo(Some(accum.value), None) - if (info.name.toString.contains("number of output rows")) { - info.update match { - case Some(n: Number) => - if (n.longValue() > maxOutputRows) { - maxOutputRows = n.longValue() - } - case _ => // Ignore. 
- } - } - } - res.sumMaxOutputRows += maxOutputRows - } + test("writing data out metrics: parquet") { + testMetricsNonDynamicPartition("parquet", "t1") } - // Run df.collect() and return aggregated metrics for each stage. - def run(df: DataFrame): List[(Long, Long, Long)] = { - val spark = df.sparkSession - val sparkContext = spark.sparkContext - val listener = new InputOutputMetricsListener() - sparkContext.addSparkListener(listener) - - try { - sparkContext.listenerBus.waitUntilEmpty(5000) - listener.reset() - df.collect() - sparkContext.listenerBus.waitUntilEmpty(5000) - } finally { - sparkContext.removeSparkListener(listener) - } - listener.getResults() + test("writing data out metrics with dynamic partition: parquet") { + testMetricsDynamicPartition("parquet", "parquet", "t1") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala new file mode 100644 index 000000000000..3966e98c1ce0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import java.io.File + +import scala.collection.mutable.HashMap + +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils + + +trait SQLMetricsTestUtils extends SQLTestUtils { + + import testImplicits._ + + /** + * Get execution metrics for the SQL execution and verify metrics values. + * + * @param metricsValues the expected metric values (numFiles, numPartitions, numOutputRows). + * @param func the function can produce execution id after running. + */ + private def verifyWriteDataMetrics(metricsValues: Seq[Int])(func: => Unit): Unit = { + val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet + // Run the given function to trigger query execution. 
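+    // The execution IDs captured above are diffed against the post-run set below; after `func`
+    // returns, the listener bus is drained so the single new execution ID (and its SQL metrics)
+    // can be looked up reliably.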
+ func + spark.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = + spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size == 1) + val executionId = executionIds.head + + val executionData = spark.sharedState.listener.getExecution(executionId).get + val executedNode = executionData.physicalPlanGraph.nodes.head + + val metricsNames = Seq( + "number of written files", + "number of dynamic part", + "number of output rows") + + val metrics = spark.sharedState.listener.getExecutionMetrics(executionId) + + metricsNames.zip(metricsValues).foreach { case (metricsName, expected) => + val sqlMetric = executedNode.metrics.find(_.name == metricsName) + assert(sqlMetric.isDefined) + val accumulatorId = sqlMetric.get.accumulatorId + val metricValue = metrics(accumulatorId).replaceAll(",", "").toInt + assert(metricValue == expected) + } + + val totalNumBytesMetric = executedNode.metrics.find(_.name == "bytes of written output").get + val totalNumBytes = metrics(totalNumBytesMetric.accumulatorId).replaceAll(",", "").toInt + assert(totalNumBytes > 0) + } + + protected def testMetricsNonDynamicPartition( + dataFormat: String, + tableName: String): Unit = { + withTable(tableName) { + Seq((1, 2)).toDF("i", "j") + .write.format(dataFormat).mode("overwrite").saveAsTable(tableName) + + val tableLocation = + new File(spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location) + + // 2 files, 100 rows, 0 dynamic partition. + verifyWriteDataMetrics(Seq(2, 0, 100)) { + (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2) + .write.format(dataFormat).mode("overwrite").insertInto(tableName) + } + assert(Utils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) + } + } + + protected def testMetricsDynamicPartition( + provider: String, + dataFormat: String, + tableName: String): Unit = { + withTempPath { dir => + spark.sql( + s""" + |CREATE TABLE $tableName(a int, b int) + |USING $provider + |PARTITIONED BY(a) + |LOCATION '${dir.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + val df = spark.range(start = 0, end = 40, step = 1, numPartitions = 1) + .selectExpr("id a", "id b") + + // 40 files, 80 rows, 40 dynamic partitions. + verifyWriteDataMetrics(Seq(40, 40, 80)) { + df.union(df).repartition(2, $"a") + .write + .format(dataFormat) + .mode("overwrite") + .insertInto(tableName) + } + assert(Utils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) + } + } + + /** + * Call `df.collect()` and collect necessary metrics from execution data. + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedNodeIds the node ids of the metrics to collect from execution data. 
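+   * @param enableWholeStage whether to collect the metrics with whole-stage codegen enabled;
+   *                         "spark.sql.codegen.wholeStage" is set to this value while `df` runs.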
+ */ + protected def getSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedNodeIds: Set[Long], + enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = { + val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet + withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) { + df.collect() + } + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = + spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= expectedNumOfJobs) + if (jobs.size == expectedNumOfJobs) { + // If we can track all jobs, check the metric values + val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId) + val metrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( + df.queryExecution.executedPlan)).allNodes.filter { node => + expectedNodeIds.contains(node.id) + }.map { node => + val nodeMetrics = node.metrics.map { metric => + val metricValue = metricValues(metric.accumulatorId) + (metric.name, metricValue) + }.toMap + (node.id, node.name -> nodeMetrics) + }.toMap + Some(metrics) + } else { + // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. + // Since we cannot track all jobs, the metric values could be wrong and we should not check + // them. + logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") + None + } + } + + /** + * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetrics the expected metrics. The format is + * `nodeId -> (operatorName, metric name -> metric value)`. + */ + protected def testSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet) + optActualMetrics.foreach { actualMetrics => + assert(expectedMetrics.keySet === actualMetrics.keySet) + for (nodeId <- expectedMetrics.keySet) { + val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId) + val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) + assert(expectedNodeName === actualNodeName) + for (metricName <- expectedMetricsMap.keySet) { + assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) + } + } + } + } +} + + +object InputOutputMetricsHelper { + private class InputOutputMetricsListener extends SparkListener { + private case class MetricsResult( + var recordsRead: Long = 0L, + var shuffleRecordsRead: Long = 0L, + var sumMaxOutputRows: Long = 0L) + + private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult] + + def reset(): Unit = { + stageIdToMetricsResult.clear() + } + + /** + * Return a list of recorded metrics aggregated per stage. + * + * The list is sorted in the ascending order on the stageId. 
+ * For each recorded stage, the following tuple is returned: + * - sum of inputMetrics.recordsRead for all the tasks in the stage + * - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage + * - sum of the highest values of "number of output rows" metric for all the tasks in the stage + */ + def getResults(): List[(Long, Long, Long)] = { + stageIdToMetricsResult.keySet.toList.sorted.map { stageId => + val res = stageIdToMetricsResult(stageId) + (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows) + } + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult()) + + res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead + res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead + + var maxOutputRows = 0L + for (accum <- taskEnd.taskMetrics.externalAccums) { + val info = accum.toInfo(Some(accum.value), None) + if (info.name.toString.contains("number of output rows")) { + info.update match { + case Some(n: Number) => + if (n.longValue() > maxOutputRows) { + maxOutputRows = n.longValue() + } + case _ => // Ignore. + } + } + } + res.sumMaxOutputRows += maxOutputRows + } + } + + // Run df.collect() and return aggregated metrics for each stage. + def run(df: DataFrame): List[(Long, Long, Long)] = { + val spark = df.sparkSession + val sparkContext = spark.sparkContext + val listener = new InputOutputMetricsListener() + sparkContext.addSparkListener(listener) + + try { + sparkContext.listenerBus.waitUntilEmpty(5000) + listener.reset() + df.collect() + sparkContext.listenerBus.waitUntilEmpty(5000) + } finally { + sparkContext.removeSparkListener(listener) + } + listener.getResults() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 2a3d1cf0b298..153e6e1f88c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -21,7 +21,8 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, In} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.BooleanType @@ -36,7 +37,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { } override def afterAll(): Unit = { - spark.sessionState.functionRegistry.dropFunction("dummyPythonUDF") + spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF")) super.afterAll() } @@ -64,7 +65,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { test("Python UDF: no push down on non-deterministic") { val df = Seq(("Hello", 4)).toDF("a", "b") - .where("b > 4 and dummyPythonUDF(a) and rand() > 3") + .where("b > 4 and dummyPythonUDF(a) and rand() > 0.3") val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec( And(_: AttributeReference, 
_: GreaterThan), @@ -76,7 +77,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { test("Python UDF: no push down on predicates starting from the first non-deterministic") { val df = Seq(("Hello", 4)).toDF("a", "b") - .where("dummyPythonUDF(a) and rand() > 3 and b > 4") + .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4") val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f } @@ -104,5 +105,8 @@ class DummyUDF extends PythonFunction( broadcastVars = null, accumulator = null) -class MyDummyPythonUDF - extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, dataType = BooleanType) +class MyDummyPythonUDF extends UserDefinedPythonFunction( + name = "dummyUDF", + func = new DummyUDF, + dataType = BooleanType, + vectorized = false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 3d480b148db5..83018f95aa55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -21,8 +21,8 @@ import java.io._ import java.nio.charset.StandardCharsets._ import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 7689bc03a4cc..48e70e48b179 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -259,6 +259,23 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { fm.rename(path2, path3) } } + + test("verifyBatchIds") { + import HDFSMetadataLog.verifyBatchIds + verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), Some(3L)) + verifyBatchIds(Seq(1L), Some(1L), Some(1L)) + verifyBatchIds(Seq(1L, 2L, 3L), None, Some(3L)) + verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), None) + verifyBatchIds(Seq(1L, 2L, 3L), None, None) + + intercept[IllegalStateException](verifyBatchIds(Seq(), Some(1L), None)) + intercept[IllegalStateException](verifyBatchIds(Seq(), None, Some(1L))) + intercept[IllegalStateException](verifyBatchIds(Seq(), Some(1L), Some(1L))) + intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), Some(1L), None)) + intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), None, Some(5L))) + intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), Some(1L), Some(5L))) + intercept[IllegalStateException](verifyBatchIds(Seq(1, 2, 4, 5), Some(1L), Some(5L))) + } } /** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 
24a7b7740fa5..e8420eee7fe9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -216,15 +216,15 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { // Before adding data, check output checkAnswer(sink.allData, Seq.empty) - assert(plan.stats(sqlConf).sizeInBytes === 0) + assert(plan.stats.sizeInBytes === 0) sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats(sqlConf).sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 12) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats(sqlConf).sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 24) } ignore("stress test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index dc556322bedd..e6cdc063c4e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -37,16 +37,18 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { } // None set - assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) + assert(new OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) // One set - assert(OffsetSeqMetadata(1, 0, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) - assert(OffsetSeqMetadata(0, 2, Map.empty) === OffsetSeqMetadata("""{"batchTimestampMs":2}""")) + assert(new OffsetSeqMetadata(1, 0, Map.empty) === + OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) + assert(new OffsetSeqMetadata(0, 2, Map.empty) === + OffsetSeqMetadata("""{"batchTimestampMs":2}""")) assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) === OffsetSeqMetadata(s"""{"conf": {"$key":2}}""")) // Two set - assert(OffsetSeqMetadata(1, 2, Map.empty) === + assert(new OffsetSeqMetadata(1, 2, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) === OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}""")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 007554a83f54..519e3c01afe8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.eclipse.jetty.util.ConcurrentHashSet import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.PatienceConfiguration.Timeout -import org.scalatest.concurrent.Timeouts._ +import org.scalatest.concurrent.TimeLimits._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala new file mode 100644 index 000000000000..03d0f63fa4d7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => + source.asInstanceOf[RateStreamSource] + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + (rateSource, rateSource.getOffset.get) + } + } + + test("basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("valueAtSecond") { + import RateStreamSource._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 
12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[ArithmeticException](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[StreamingQueryException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + for (msg <- expectedMessages) { + assert(e.getCause.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala index 5174a0415304..ec1154907365 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala @@ -65,20 +65,22 @@ 
class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before while (source.getOffset.isEmpty) { Thread.sleep(10) } - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - assert(batch1.as[String].collect().toSeq === Seq("hello")) + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val offset1 = source.getOffset.get + val batch1 = source.getBatch(None, offset1) + assert(batch1.as[String].collect().toSeq === Seq("hello")) + + serverThread.enqueue("world") + while (source.getOffset.get === offset1) { + Thread.sleep(10) + } + val offset2 = source.getOffset.get + val batch2 = source.getBatch(Some(offset1), offset2) + assert(batch2.as[String].collect().toSeq === Seq("world")) - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) + val both = source.getBatch(None, offset2) + assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world")) } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - assert(batch2.as[String].collect().toSeq === Seq("world")) - - val both = source.getBatch(None, offset2) - assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world")) // Try stopping the source to make sure this does not block forever. source.stop() @@ -104,22 +106,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before while (source.getOffset.isEmpty) { Thread.sleep(10) } - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq - assert(batch1Seq.map(_._1) === Seq("hello")) - val batch1Stamp = batch1Seq(0)._2 - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val offset1 = source.getOffset.get + val batch1 = source.getBatch(None, offset1) + val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq + assert(batch1Seq.map(_._1) === Seq("hello")) + val batch1Stamp = batch1Seq(0)._2 + + serverThread.enqueue("world") + while (source.getOffset.get === offset1) { + Thread.sleep(10) + } + val offset2 = source.getOffset.get + val batch2 = source.getBatch(Some(offset1), offset2) + val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq + assert(batch2Seq.map(_._1) === Seq("world")) + val batch2Stamp = batch2Seq(0)._2 + assert(!batch2Stamp.before(batch1Stamp)) } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq - assert(batch2Seq.map(_._1) === Seq("world")) - val batch2Stamp = batch2Seq(0)._2 - assert(!batch2Stamp.before(batch1Stamp)) // Try stopping the source to make sure this does not block forever. 
source.stop() @@ -148,6 +152,21 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("user-specified schema given") { + val provider = new TextSocketSourceProvider + val userSpecifiedSchema = StructType( + StructField("name", StringType) :: + StructField("area", StringType) :: Nil) + val exception = intercept[AnalysisException] { + provider.sourceSchema( + sqlContext, Some(userSpecifiedSchema), + "", + Map("host" -> "localhost", "port" -> "1234")) + } + assert(exception.getMessage.contains( + "socket source does not support a user-specified schema")) + } + test("no server up") { val provider = new TextSocketSourceProvider val parameters = Map("host" -> "localhost", "port" -> "0") @@ -169,12 +188,14 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before while (source.getOffset.isEmpty) { Thread.sleep(10) } - val batch = source.getBatch(None, source.getOffset.get).as[String] - batch.collect() - val numRowsMetric = - batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") - assert(numRowsMetric.nonEmpty) - assert(numRowsMetric.get.value === 1) + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val batch = source.getBatch(None, source.getOffset.get).as[String] + batch.collect() + val numRowsMetric = + batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") + assert(numRowsMetric.nonEmpty) + assert(numRowsMetric.get.value === 1) + } source.stop() source = null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index a7e32626264c..9a7595eee7bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -17,11 +17,17 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count +import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -29,7 +35,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("report, verify, getLocation") { withCoordinatorRef(sc) { coordinatorRef => - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) assert(coordinatorRef.getLocation(id) === None) @@ -57,9 +63,11 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("make inactive") { withCoordinatorRef(sc) { coordinatorRef => - val id1 = StateStoreId("x", 0, 0) - val id2 = StateStoreId("y", 1, 0) - val id3 = StateStoreId("x", 0, 1) + val runId1 = UUID.randomUUID + val runId2 = UUID.randomUUID + val id1 = StateStoreProviderId(StateStoreId("x", 0, 0), runId1) + val id2 = StateStoreProviderId(StateStoreId("y", 1, 0), runId2) + val id3 = StateStoreProviderId(StateStoreId("x", 0, 
1), runId1) val host = "hostX" val exec = "exec1" @@ -73,7 +81,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) } - coordinatorRef.deactivateInstances("x") + coordinatorRef.deactivateInstances(runId1) assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) @@ -85,7 +93,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { Some(ExecutorCacheTaskLocation(host, exec).toString)) assert(coordinatorRef.getLocation(id3) === None) - coordinatorRef.deactivateInstances("y") + coordinatorRef.deactivateInstances(runId2) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) assert(coordinatorRef.getLocation(id2) === None) } @@ -95,7 +103,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { withCoordinatorRef(sc) { coordRef1 => val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) coordRef1.reportActiveInstance(id, "hostX", "exec1") @@ -107,6 +115,45 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } } + + test("query stop deactivates related store providers") { + var coordRef: StateStoreCoordinatorRef = null + try { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + SparkSession.setActiveSession(spark) + import spark.implicits._ + coordRef = spark.streams.stateStoreCoordinator + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + + // Start a query and run a batch to load state stores + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + + // Verify state store has been loaded + val stateCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, 0), query.runId) + assert(coordRef.getLocation(providerId).nonEmpty) + + // Stop and verify whether the stores are deactivated in the coordinator + query.stop() + assert(coordRef.getLocation(providerId).isEmpty) + } finally { + SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop())) + if (coordRef != null) coordRef.stop() + StateStore.stop() + } + } } object StateStoreCoordinatorSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index bd197be655d5..defb9ed63a88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -19,32 +19,31 @@ package org.apache.spark.sql.execution.streaming.state import java.io.File import java.nio.file.Files +import java.util.UUID import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import 
org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.sql.LocalSparkSession._ -import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{CompletionIterator, Utils} class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + import StateStoreTestsHelper._ + private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) - private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString + private val tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString private val keySchema = StructType(Seq(StructField("key", StringType, true))) private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) - import StateStoreSuite._ - after { StateStore.stop() } @@ -57,16 +56,15 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val opId = 0 - val rdd1 = - makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + spark.sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( + increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
@@ -75,7 +73,6 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString def makeStoreRDD( @@ -84,7 +81,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn storeVersion: Int): RDD[(String, Int)] = { implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) + sqlContext, operatorStateInfo(path, version = storeVersion), + keySchema, valueSchema, None)(increment) } // Generate RDDs and state store data @@ -110,7 +108,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = { val resIterator = iter.map { s => val key = stringToRow(s) - val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0) val newValue = oldValue + 1 store.put(key, intToRow(newValue)) (s, newValue) @@ -125,42 +123,49 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn iter: Iterator[String]): Iterator[(String, Option[Int])] = { iter.map { s => val key = stringToRow(s) - val value = store.get(key).map(rowToInt) + val value = Option(store.get(key)).map(rowToInt) (s, value) } } val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( + iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) + sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( + iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( + iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } } test("preferred locations using StateStoreCoordinator") { quietly { + val queryRunId = UUID.randomUUID val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + val storeProviderId1 = StateStoreProviderId(StateStoreId(path, opId, 0), queryRunId) + val storeProviderId2 = StateStoreProviderId(StateStoreId(path, opId, 1), queryRunId) + coordinatorRef.reportActiveInstance(storeProviderId1, "host1", "exec1") + coordinatorRef.reportActiveInstance(storeProviderId2, "host2", "exec2") - assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + require( + 
coordinatorRef.getLocation(storeProviderId1) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) + sqlContext, operatorStateInfo(path, queryRunId = queryRunId), + keySchema, valueSchema, None)(increment) require(rdd.partitions.length === 2) assert( @@ -187,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) + sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -205,10 +210,17 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) } + private def operatorStateInfo( + path: String, + queryRunId: UUID = UUID.randomUUID, + version: Int = 0): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version) + } + private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => val key = stringToRow(s) - val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0) store.put(key, intToRow(oldValue + 1)) } store.commit() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index ebb7422765eb..c843b65020d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -33,22 +35,25 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { +class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider] + with BeforeAndAfter with PrivateMethodTester { type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] import StateStoreCoordinatorSuite._ - import StateStoreSuite._ + import StateStoreTestsHelper._ - private val tempDir = Utils.createTempDir().toString - private val keySchema = StructType(Seq(StructField("key", StringType, true))) - private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + val keySchema = StructType(Seq(StructField("key", StringType, true))) + val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) before { StateStore.stop() @@ -60,186 +65,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth require(!StateStore.isMaintenanceRunning) } - test("get, put, remove, commit, and all data iterator") { - val provider = newStoreProvider() - - // Verify state before starting a new set of updates - assert(provider.latestIterator().isEmpty) - - val store = provider.getStore(0) - assert(!store.hasCommitted) - intercept[IllegalStateException] { - store.iterator() - } - intercept[IllegalStateException] { - store.updates() - } - - // Verify state after updating - put(store, "a", 1) - assert(store.numKeys() === 1) - intercept[IllegalStateException] { - store.iterator() - } - intercept[IllegalStateException] { - store.updates() - } - assert(provider.latestIterator().isEmpty) - - // Make updates, commit and then verify state - put(store, "b", 2) - put(store, "aa", 3) - assert(store.numKeys() === 3) - remove(store, _.startsWith("a")) - assert(store.numKeys() === 1) - assert(store.commit() === 1) - - assert(store.hasCommitted) - assert(rowsToSet(store.iterator()) === Set("b" -> 2)) - assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2)) - assert(fileExists(provider, version = 1, isSnapshot = false)) - - assert(getDataFromFiles(provider) === Set("b" -> 2)) - - // Trying to get newer versions should fail - intercept[Exception] { - provider.getStore(2) - } - intercept[Exception] { - getDataFromFiles(provider, 2) - } - - // New updates to the reloaded store with new version, and does not change old version - val reloadedProvider = new HDFSBackedStateStoreProvider( - store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) - val reloadedStore = reloadedProvider.getStore(1) - assert(reloadedStore.numKeys() === 1) - put(reloadedStore, "c", 4) - assert(reloadedStore.numKeys() === 2) - assert(reloadedStore.commit() === 2) - assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) - assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) - assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2)) - assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) - } - - test("filter and concurrent updates") { - val provider = newStoreProvider() - - // Verify state before starting a new set of updates - assert(provider.latestIterator.isEmpty) - val store = provider.getStore(0) - put(store, "a", 1) - put(store, "b", 2) - - // Updates should work while iterating of filtered entries - val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" } - filtered.foreach { case (keyRow, valueRow) => - store.put(keyRow, intToRow(rowToInt(valueRow) + 1)) - } - assert(get(store, "a") === Some(2)) - - // Removes should work while iterating of filtered entries - val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" } - filtered2.foreach { case (keyRow, _) => - store.remove(keyRow) - } - 
assert(get(store, "b") === None) - } - - test("updates iterator with all combos of updates and removes") { - val provider = newStoreProvider() - var currentVersion: Int = 0 - - def withStore(body: StateStore => Unit): Unit = { - val store = provider.getStore(currentVersion) - body(store) - currentVersion += 1 - } - - // New data should be seen in updates as value added, even if they had multiple updates - withStore { store => - put(store, "a", 1) - put(store, "aa", 1) - put(store, "aa", 2) - store.commit() - assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) - assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) - } - - // Multiple updates to same key should be collapsed in the updates as a single value update - // Keys that have not been updated should not appear in the updates - withStore { store => - put(store, "a", 4) - put(store, "a", 6) - store.commit() - assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) - assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) - } - - // Keys added, updated and finally removed before commit should not appear in updates - withStore { store => - put(store, "b", 4) // Added, finally removed - put(store, "bb", 5) // Added, updated, finally removed - put(store, "bb", 6) - remove(store, _.startsWith("b")) - store.commit() - assert(updatesToSet(store.updates()) === Set.empty) - assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) - } - - // Removed data should be seen in updates as a key removed - // Removed, but re-added data should be seen in updates as a value update - withStore { store => - remove(store, _.startsWith("a")) - put(store, "a", 10) - store.commit() - assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) - assert(rowsToSet(store.iterator()) === Set("a" -> 10)) - } - } - - test("cancel") { - val provider = newStoreProvider() - val store = provider.getStore(0) - put(store, "a", 1) - store.commit() - assert(rowsToSet(store.iterator()) === Set("a" -> 1)) - - // cancelUpdates should not change the data in the files - val store1 = provider.getStore(1) - put(store1, "b", 1) - store1.abort() - assert(getDataFromFiles(provider) === Set("a" -> 1)) - } - - test("getStore with unexpected versions") { - val provider = newStoreProvider() - - intercept[IllegalArgumentException] { - provider.getStore(-1) - } - - // Prepare some data in the store - val store = provider.getStore(0) - put(store, "a", 1) - assert(store.commit() === 1) - assert(rowsToSet(store.iterator()) === Set("a" -> 1)) - - intercept[IllegalStateException] { - provider.getStore(2) - } - - // Update store version with some data - val store1 = provider.getStore(1) - put(store1, "b", 1) - assert(store1.commit() === 2) - assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) - assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) - } - test("snapshotting") { - val provider = newStoreProvider(minDeltasForSnapshot = 5) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) var currentVersion = 0 def updateVersionTo(targetVersion: Int): Unit = { @@ -253,9 +80,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } updateVersionTo(2) - require(getDataFromFiles(provider) === Set("a" -> 2)) + require(getData(provider) === Set("a" -> 2)) provider.doMaintenance() // should not generate snapshot files - assert(getDataFromFiles(provider) === Set("a" -> 2)) + assert(getData(provider) === Set("a" -> 2)) for (i <- 1 to 
currentVersion) { assert(fileExists(provider, i, isSnapshot = false)) // all delta files present @@ -264,22 +91,22 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // After version 6, snapshotting should generate one snapshot file updateVersionTo(6) - require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly") + require(getData(provider) === Set("a" -> 6), "store not updated correctly") provider.doMaintenance() // should generate snapshot files val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) assert(snapshotVersion.nonEmpty, "snapshot file not generated") deleteFilesEarlierThanVersion(provider, snapshotVersion.get) assert( - getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), + getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), "snapshotting messed up the data of the snapshotted version") assert( - getDataFromFiles(provider) === Set("a" -> 6), + getData(provider) === Set("a" -> 6), "snapshotting messed up the data of the final version") // After version 20, snapshotting should generate newer snapshot files updateVersionTo(20) - require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly") + require(getData(provider) === Set("a" -> 20), "store not updated correctly") provider.doMaintenance() // do snapshot val latestSnapshotVersion = (0 to 20).filter(version => @@ -288,11 +115,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get) - assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") + assert(getData(provider) === Set("a" -> 20), "snapshotting messed up the data") } test("cleaning") { - val provider = newStoreProvider(minDeltasForSnapshot = 5) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) for (i <- 1 to 20) { val store = provider.getStore(i - 1) @@ -307,27 +134,27 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted // last couple of versions should be retrievable - assert(getDataFromFiles(provider, 20) === Set("a" -> 20)) - assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) + assert(getData(provider, 20) === Set("a" -> 20)) + assert(getData(provider, 19) === Set("a" -> 19)) } test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { val conf = new Configuration() conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) - conf.set("fs.default.name", "fake:///") + conf.set("fs.defaultFS", "fake:///") - val provider = newStoreProvider(hadoopConf = conf) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf) provider.getStore(0).commit() provider.getStore(0).commit() // Verify we don't leak temp files - val tempFiles = FileUtils.listFiles(new File(provider.id.checkpointLocation), + val tempFiles = FileUtils.listFiles(new File(provider.stateStoreId.checkpointRootLocation), null, true).asScala.filter(_.getName.startsWith("temp-")) assert(tempFiles.isEmpty) } test("corrupted file handling") { - val provider = newStoreProvider(minDeltasForSnapshot = 5) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, 
minDeltasForSnapshot = 5) for (i <- 1 to 6) { val store = provider.getStore(i - 1) put(store, "a", i) @@ -338,62 +165,84 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found")) // Corrupt snapshot file and verify that it throws error - assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion)) + assert(getData(provider, snapshotVersion) === Set("a" -> snapshotVersion)) corruptFile(provider, snapshotVersion, isSnapshot = true) intercept[Exception] { - getDataFromFiles(provider, snapshotVersion) + getData(provider, snapshotVersion) } // Corrupt delta file and verify that it throws error - assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) + assert(getData(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) corruptFile(provider, snapshotVersion - 1, isSnapshot = false) intercept[Exception] { - getDataFromFiles(provider, snapshotVersion - 1) + getData(provider, snapshotVersion - 1) } // Delete delta file and verify that it throws error deleteFilesEarlierThanVersion(provider, snapshotVersion) intercept[Exception] { - getDataFromFiles(provider, snapshotVersion - 1) + getData(provider, snapshotVersion - 1) } } + test("reports memory usage") { + val provider = newStoreProvider() + val store = provider.getStore(0) + val noDataMemoryUsed = store.metrics.memoryUsedBytes + put(store, "a", 1) + store.commit() + assert(store.metrics.memoryUsedBytes > noDataMemoryUsed) + } + test("StateStore.get") { quietly { - val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString - val storeId = StateStoreId(dir, 0, 0) + val dir = newDir() + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() - // Verify that trying to get incorrect versions throw errors intercept[IllegalArgumentException] { - StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf) } assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store intercept[IllegalStateException] { - StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) } - // Increase version of the store - val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + // Increase version of the store and try to get again + val store0 = StateStore.get( + storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf) assert(store0.version === 0) put(store0, "a", 1) store0.commit() - assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1) - assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0) + val store1 = StateStore.get( + storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + assert(store1.version === 1) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1)) + + // Verify that you can also load older version + val store0reloaded = StateStore.get( + storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf) + assert(store0reloaded.version === 0) + assert(rowsToSet(store0reloaded.iterator()) === Set.empty) // Verify that you can remove the store and still reload and use it 
StateStore.unload(storeId) assert(!StateStore.isLoaded(storeId)) - val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + val store1reloaded = StateStore.get( + storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) - put(store1, "a", 2) - assert(store1.commit() === 2) - assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) + assert(store1reloaded.version === 1) + put(store1reloaded, "a", 2) + assert(store1reloaded.commit() === 2) + assert(rowsToSet(store1reloaded.iterator()) === Set("a" -> 2)) } } @@ -407,21 +256,20 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // fails to talk to the StateStoreCoordinator and unloads all the StateStores .set("spark.rpc.numRetries", "1") val opId = 0 - val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString - val storeId = StateStoreId(dir, opId, 0) + val dir = newDir() + val storeProviderId = StateStoreProviderId(StateStoreId(dir, opId, 0), UUID.randomUUID) val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val storeConf = StateStoreConf(sqlConf) val hadoopConf = new Configuration() - val provider = new HDFSBackedStateStoreProvider( - storeId, keySchema, valueSchema, storeConf, hadoopConf) + val provider = newStoreProvider(storeProviderId.storeId) var latestStoreVersion = 0 def generateStoreVersions() { for (i <- 1 to 20) { - val store = StateStore.get( - storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) + val store = StateStore.get(storeProviderId, keySchema, valueSchema, None, + latestStoreVersion, storeConf, hadoopConf) put(store, "a", i) store.commit() latestStoreVersion += 1 @@ -440,7 +288,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth eventually(timeout(timeoutDuration)) { // Store should have been reported to the coordinator - assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") + assert(coordinatorRef.getLocation(storeProviderId).nonEmpty, + "active instance was not reported") // Background maintenance should clean up and generate snapshots assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") @@ -461,33 +310,35 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") } - // If driver decides to deactivate all instances of the store, then this instance - // should be unloaded - coordinatorRef.deactivateInstances(dir) + // If driver decides to deactivate all stores related to a query run, + // then this instance should be unloaded + coordinatorRef.deactivateInstances(storeProviderId.queryRunId) eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, + latestStoreVersion, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeProviderId)) // If some other executor loads the store, then this instance should be unloaded - coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") + coordinatorRef.reportActiveInstance(storeProviderId, "other-host", "other-exec") eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) 
+ assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, + latestStoreVersion, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeProviderId)) } } // Verify if instance is unloaded if SparkContext is stopped eventually(timeout(timeoutDuration)) { require(SparkEnv.get === null) - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) assert(!StateStore.isMaintenanceRunning) } } @@ -495,10 +346,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("SPARK-18342: commit fails when rename fails") { import RenameReturnsFalseFileSystem._ - val dir = scheme + "://" + Utils.createDirectory(tempDir, Random.nextString(5)).toURI.getPath + val dir = scheme + "://" + newDir() val conf = new Configuration() conf.set(s"fs.$scheme.impl", classOf[RenameReturnsFalseFileSystem].getName) - val provider = newStoreProvider(dir = dir, hadoopConf = conf) + val provider = newStoreProvider( + opId = Random.nextInt, partition = 0, dir = dir, hadoopConf = conf) val store = provider.getStore(0) put(store, "a", 0) val e = intercept[IllegalStateException](store.commit()) @@ -506,8 +358,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } test("SPARK-18416: do not create temp delta file until the store is updated") { - val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString - val storeId = StateStoreId(dir, 0, 0) + val dir = newDir() + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() val deltaFileDir = new File(s"$dir/0/0/") @@ -533,7 +385,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Getting the store should not create temp file val store0 = shouldNotCreateTempFile { - StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, indexOrdinal = None, version = 0, storeConf, hadoopConf) } // Put should create a temp file @@ -548,7 +401,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Remove should create a temp file val store1 = shouldNotCreateTempFile { - StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, indexOrdinal = None, version = 1, storeConf, hadoopConf) } remove(store1, _ == "a") assert(numTempFiles === 1) @@ -561,31 +415,103 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Commit without any updates should create a delta file val store2 = shouldNotCreateTempFile { - StateStore.get(storeId, keySchema, valueSchema, 2, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, indexOrdinal = None, version = 2, storeConf, hadoopConf) } store2.commit() assert(numTempFiles === 0) assert(numDeltaFiles === 3) } - def getDataFromFiles( - provider: HDFSBackedStateStoreProvider, + test("SPARK-21145: Restarted queries create new provider instances") { + try { + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val spark = SparkSession.builder().master("local[2]").getOrCreate() + SparkSession.setActiveSession(spark) + implicit val sqlContext = spark.sqlContext + 
spark.conf.set("spark.sql.shuffle.partitions", "1") + import spark.implicits._ + val inputData = MemoryStream[Int] + + def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = { + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) + // stateful query + val query = aggregated.writeStream + .format("memory") + .outputMode("complete") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + require(query.lastProgress != null) // at least one batch processed after start + val loadedProvidersMethod = + PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]('loadedProviders) + val loadedProvidersMap = StateStore invokePrivate loadedProvidersMethod() + val loadedProviders = loadedProvidersMap.synchronized { loadedProvidersMap.values.toSeq } + query.stop() + loadedProviders + } + + val loadedProvidersAfterRun1 = runQueryAndGetLoadedProviders() + require(loadedProvidersAfterRun1.length === 1) + + val loadedProvidersAfterRun2 = runQueryAndGetLoadedProviders() + assert(loadedProvidersAfterRun2.length === 2) // two providers loaded for 2 runs + + // Both providers should have the same StateStoreId, but the should be different objects + assert(loadedProvidersAfterRun2(0).stateStoreId === loadedProvidersAfterRun2(1).stateStoreId) + assert(loadedProvidersAfterRun2(0) ne loadedProvidersAfterRun2(1)) + + } finally { + SparkSession.getActiveSession.foreach { spark => + spark.streams.active.foreach(_.stop()) + spark.stop() + } + } + } + + override def newStoreProvider(): HDFSBackedStateStoreProvider = { + newStoreProvider(opId = Random.nextInt(), partition = 0) + } + + override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = { + newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointRootLocation) + } + + override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = { + getData(storeProvider) + } + + override def getData( + provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { - val reloadedProvider = new HDFSBackedStateStoreProvider( - provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) + val reloadedProvider = newStoreProvider(provider.stateStoreId) if (version < 0) { reloadedProvider.latestIterator().map(rowsToStringInt).toSet } else { - reloadedProvider.iterator(version).map(rowsToStringInt).toSet + reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet } } - def assertMap( - testMapOption: Option[MapType], - expectedMap: Map[String, Int]): Unit = { - assert(testMapOption.nonEmpty, "no map present") - val convertedMap = testMapOption.get.map(rowsToStringInt) - assert(convertedMap === expectedMap) + def newStoreProvider( + opId: Long, + partition: Int, + dir: String = newDir(), + minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = { + val sqlConf = new SQLConf() + sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) + val provider = new HDFSBackedStateStoreProvider() + provider.init( + StateStoreId(dir, opId, partition), + keySchema, + valueSchema, + indexOrdinal = None, + new StateStoreConf(sqlConf), + hadoopConf) + provider } def fileExists( @@ -622,56 +548,181 @@ class StateStoreSuite extends SparkFunSuite with 
BeforeAndAfter with PrivateMeth filePath.delete() filePath.createNewFile() } +} - def storeLoaded(storeId: StateStoreId): Boolean = { - val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores) - val loadedStores = StateStore invokePrivate method() - loadedStores.contains(storeId) - } +abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] + extends SparkFunSuite { + import StateStoreTestsHelper._ - def unloadStore(storeId: StateStoreId): Boolean = { - val method = PrivateMethod('remove) - StateStore invokePrivate method(storeId) - } + test("get, put, remove, commit, and all data iterator") { + val provider = newStoreProvider() - def newStoreProvider( - opId: Long = Random.nextLong, - partition: Int = 0, - minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, - dir: String = Utils.createDirectory(tempDir, Random.nextString(5)).toString, - hadoopConf: Configuration = new Configuration() - ): HDFSBackedStateStoreProvider = { - val sqlConf = new SQLConf() - sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) - sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) - new HDFSBackedStateStoreProvider( - StateStoreId(dir, opId, partition), - keySchema, - valueSchema, - new StateStoreConf(sqlConf), - hadoopConf) + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + assert(get(store, "a") === None) + assert(store.iterator().isEmpty) + assert(store.metrics.numKeys === 0) + + // Verify state after updating + put(store, "a", 1) + assert(get(store, "a") === Some(1)) + assert(store.metrics.numKeys === 1) + + assert(store.iterator().nonEmpty) + assert(getLatestData(provider).isEmpty) + + // Make updates, commit and then verify state + put(store, "b", 2) + put(store, "aa", 3) + assert(store.metrics.numKeys === 3) + remove(store, _.startsWith("a")) + assert(store.metrics.numKeys === 1) + assert(store.commit() === 1) + + assert(store.hasCommitted) + assert(rowsToSet(store.iterator()) === Set("b" -> 2)) + assert(getLatestData(provider) === Set("b" -> 2)) + + // Trying to get newer versions should fail + intercept[Exception] { + provider.getStore(2) + } + intercept[Exception] { + getData(provider, 2) + } + + // New updates to the reloaded store with new version, and does not change old version + val reloadedProvider = newStoreProvider(store.id) + val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.metrics.numKeys === 1) + put(reloadedStore, "c", 4) + assert(reloadedStore.metrics.numKeys === 2) + assert(reloadedStore.commit() === 2) + assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) + assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4)) + assert(getData(provider, version = 1) === Set("b" -> 2)) } - def remove(store: StateStore, condition: String => Boolean): Unit = { - store.remove(row => condition(rowToString(row))) + test("removing while iterating") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + val store = provider.getStore(0) + put(store, "a", 1) + put(store, "b", 2) + + // Updates should work while iterating of filtered entries + val filtered = store.iterator.filter { tuple => rowToString(tuple.key) == "a" } + filtered.foreach { tuple => + store.put(tuple.key, intToRow(rowToInt(tuple.value) + 1)) + } + assert(get(store, "a") === Some(2)) + + // 
Removes should work while iterating of filtered entries + val filtered2 = store.iterator.filter { tuple => rowToString(tuple.key) == "b" } + filtered2.foreach { tuple => store.remove(tuple.key) } + assert(get(store, "b") === None) } - private def put(store: StateStore, key: String, value: Int): Unit = { - store.put(stringToRow(key), intToRow(value)) + test("abort") { + val provider = newStoreProvider() + val store = provider.getStore(0) + put(store, "a", 1) + store.commit() + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + // cancelUpdates should not change the data in the files + val store1 = provider.getStore(1) + put(store1, "b", 1) + store1.abort() } - private def get(store: StateStore, key: String): Option[Int] = { - store.get(stringToRow(key)).map(rowToInt) + test("getStore with invalid versions") { + val provider = newStoreProvider() + + def checkInvalidVersion(version: Int): Unit = { + intercept[Exception] { + provider.getStore(version) + } + } + + checkInvalidVersion(-1) + checkInvalidVersion(1) + + val store = provider.getStore(0) + put(store, "a", 1) + assert(store.commit() === 1) + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + val store1_ = provider.getStore(1) + assert(rowsToSet(store1_.iterator()) === Set("a" -> 1)) + + checkInvalidVersion(-1) + checkInvalidVersion(2) + + // Update store version with some data + val store1 = provider.getStore(1) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1)) + put(store1, "b", 1) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) + + checkInvalidVersion(-1) + checkInvalidVersion(3) + } + + test("two concurrent StateStores - one for read-only and one for read-write") { + // During Streaming Aggregation, we have two StateStores per task, one used as read-only in + // `StateStoreRestoreExec`, and one read-write used in `StateStoreSaveExec`. `StateStore.abort` + // will be called for these StateStores if they haven't committed their results. We need to + // make sure that `abort` in read-only store after a `commit` in the read-write store doesn't + // accidentally lead to the deletion of state. 
+ val dir = newDir() + val storeId = StateStoreId(dir, 0L, 1) + val provider0 = newStoreProvider(storeId) + // prime state + val store = provider0.getStore(0) + val key = "a" + put(store, key, 1) + store.commit() + assert(rowsToSet(store.iterator()) === Set(key -> 1)) + + // two state stores + val provider1 = newStoreProvider(storeId) + val restoreStore = provider1.getStore(1) + val saveStore = provider1.getStore(1) + + put(saveStore, key, get(restoreStore, key).get + 1) + saveStore.commit() + restoreStore.abort() + + // check that state is correct for next batch + val provider2 = newStoreProvider(storeId) + val finalStore = provider2.getStore(2) + assert(rowsToSet(finalStore.iterator()) === Set(key -> 2)) } -} -private[state] object StateStoreSuite { + /** Return a new provider with a random id */ + def newStoreProvider(): ProviderClass + + /** Return a new provider with the given id */ + def newStoreProvider(storeId: StateStoreId): ProviderClass + + /** Get the latest data referred to by the given provider but not using this provider */ + def getLatestData(storeProvider: ProviderClass): Set[(String, Int)] + + /** + * Get a specific version of data referred to by the given provider but not using + * this provider + */ + def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)] +} - /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */ - trait TestUpdate - case class Added(key: String, value: Int) extends TestUpdate - case class Updated(key: String, value: Int) extends TestUpdate - case class Removed(key: String) extends TestUpdate +object StateStoreTestsHelper { val strProj = UnsafeProjection.create(Array[DataType](StringType)) val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) @@ -692,26 +743,29 @@ private[state] object StateStoreSuite { row.getInt(0) } - def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = { - (rowToInt(row._1), rowToInt(row._2)) + def rowsToStringInt(row: UnsafeRowPair): (String, Int) = { + (rowToString(row.key), rowToInt(row.value)) } + def rowsToSet(iterator: Iterator[UnsafeRowPair]): Set[(String, Int)] = { + iterator.map(rowsToStringInt).toSet + } - def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = { - (rowToString(row._1), rowToInt(row._2)) + def remove(store: StateStore, condition: String => Boolean): Unit = { + store.getRange(None, None).foreach { rowPair => + if (condition(rowToString(rowPair.key))) store.remove(rowPair.key) + } } - def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = { - iterator.map(rowsToStringInt).toSet + def put(store: StateStore, key: String, value: Int): Unit = { + store.put(stringToRow(key), intToRow(value)) } - def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { - iterator.map { - case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value)) - case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value)) - case ValueRemoved(key, _) => Removed(rowToString(key)) - }.toSet + def get(store: StateStore, key: String): Option[Int] = { + Option(store.get(stringToRow(key))).map(rowToInt) } + + def newDir(): String = Utils.createTempDir().toString } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala new file mode 100644 index 000000000000..ffa4c3c22a19 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.UUID + +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types._ + +class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter { + + before { + SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator + } + + + test("SymmetricHashJoinStateManager - all operations") { + withJoinStateManager(inputValueAttribs, joinKeyExprs) { manager => + implicit val mgr = manager + + assert(get(20) === Seq.empty) // initially empty + append(20, 2) + assert(get(20) === Seq(2)) // should first value correctly + assert(numRows === 1) + + append(20, 3) + assert(get(20) === Seq(2, 3)) // should append new values + append(20, 3) + assert(get(20) === Seq(2, 3, 3)) // should append another copy if same value added again + assert(numRows === 3) + + assert(get(30) === Seq.empty) + append(30, 1) + assert(get(30) === Seq(1)) + assert(get(20) === Seq(2, 3, 3)) // add another key-value should not affect existing ones + assert(numRows === 4) + + removeByKey(25) + assert(get(20) === Seq.empty) + assert(get(30) === Seq(1)) // should remove 20, not 30 + assert(numRows === 1) + + removeByKey(30) + assert(get(30) === Seq.empty) // should remove 30 + assert(numRows === 0) + + def appendAndTest(key: Int, values: Int*): Unit = { + values.foreach { value => append(key, value)} + require(get(key) === values) + } + + appendAndTest(40, 100, 200, 300) + appendAndTest(50, 125) + appendAndTest(60, 275) // prepare for testing removeByValue + assert(numRows === 5) + + removeByValue(125) + assert(get(40) === Seq(200, 300)) + assert(get(50) === Seq.empty) + assert(get(60) === Seq(275)) // should remove only some values, not all + assert(numRows === 3) + + append(40, 50) + assert(get(40) === Seq(50, 200, 300)) + assert(numRows === 4) + + removeByValue(200) + assert(get(40) === Seq(300)) + 
assert(get(60) === Seq(275)) // should remove only some values, not all + assert(numRows === 2) + + removeByValue(300) + assert(get(40) === Seq.empty) + assert(get(60) === Seq.empty) // should remove all values now + assert(numRows === 0) + } + } + val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build() + val inputValueSchema = new StructType() + .add(StructField("time", IntegerType, metadata = watermarkMetadata)) + .add(StructField("value", BooleanType)) + val inputValueAttribs = inputValueSchema.toAttributes + val inputValueAttribWithWatermark = inputValueAttribs(0) + val joinKeyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0)) + + val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray) + val joinKeyGen = UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray) + + + def toInputValue(i: Int): UnsafeRow = { + inputValueGen.apply(new GenericInternalRow(Array[Any](i, false))) + } + + def toJoinKeyRow(i: Int): UnsafeRow = { + joinKeyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0))) + } + + def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0) + + def append(key: Int, value: Int)(implicit manager: SymmetricHashJoinStateManager): Unit = { + manager.append(toJoinKeyRow(key), toInputValue(value)) + } + + def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = { + manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted + } + + /** Remove keys (and corresponding values) where `time <= threshold` */ + def removeByKey(threshold: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { + val expr = + LessThanOrEqual( + BoundReference( + 1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable), + Literal(threshold)) + manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _) + } + + /** Remove values where `time <= threshold` */ + def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { + val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) + manager.removeByValueCondition( + GeneratePredicate.generate(expr, inputValueAttribs).eval _) + } + + def numRows(implicit manager: SymmetricHashJoinStateManager): Long = { + manager.metrics.numKeys + } + + + def withJoinStateManager( + inputValueAttribs: Seq[Attribute], + joinKeyExprs: Seq[Expression])(f: SymmetricHashJoinStateManager => Unit): Unit = { + + withTempDir { file => + val storeConf = new StateStoreConf() + val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0) + val manager = new SymmetricHashJoinStateManager( + LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration) + try { + f(manager) + } finally { + manager.abortIfNeeded() + } + } + StateStore.stop() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index e6cd41e4facf..1055f09f5411 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -23,6 +23,7 @@ import org.json4s.jackson.JsonMethods._ import org.mockito.Mockito.mock import org.apache.spark._ +import org.apache.spark.LocalSparkContext._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config import 
org.apache.spark.rdd.RDD @@ -394,7 +395,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest } // Listener tracks only SQL metrics, not other accumulators assert(trackedAccums.size === 1) - assert(trackedAccums.head === (sqlMetricInfo.id, sqlMetricInfo.update.get)) + assert(trackedAccums.head === ((sqlMetricInfo.id, sqlMetricInfo.update.get))) } test("driver side SQL metrics") { @@ -496,8 +497,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .setAppName("test") .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly - val sc = new SparkContext(conf) - try { + withSpark(new SparkContext(conf)) { sc => SparkSession.sqlListener.set(null) val spark = new SparkSession(sc) import spark.implicits._ @@ -522,8 +522,6 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { assert(spark.sharedState.listener.executionIdToData.size <= 100) assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100) assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100) - } finally { - sc.stop() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala new file mode 100644 index 000000000000..d24a9e1f4bd1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.vectorized + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.arrow.ArrowUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ArrowColumnVectorSuite extends SparkFunSuite { + + test("boolean") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true) + .createVector(allocator).asInstanceOf[NullableBitVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, if (i % 2 == 0) 1 else 0) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === BooleanType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getBoolean(i) === (i % 2 == 0)) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getBooleans(0, 10) === (0 until 10).map(i => (i % 2 == 0))) + + columnVector.close() + allocator.close() + } + + test("byte") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true) + .createVector(allocator).asInstanceOf[NullableTinyIntVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i.toByte) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === ByteType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getByte(i) === i.toByte) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getBytes(0, 10) === (0 until 10).map(i => i.toByte)) + + columnVector.close() + allocator.close() + } + + test("short") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true) + .createVector(allocator).asInstanceOf[NullableSmallIntVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i.toShort) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === ShortType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getShort(i) === i.toShort) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getShorts(0, 10) === (0 until 10).map(i => i.toShort)) + + columnVector.close() + allocator.close() + } + + test("int") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true) + .createVector(allocator).asInstanceOf[NullableIntVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === IntegerType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + 
assert(columnVector.getInt(i) === i) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getInts(0, 10) === (0 until 10)) + + columnVector.close() + allocator.close() + } + + test("long") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("long", LongType, nullable = true) + .createVector(allocator).asInstanceOf[NullableBigIntVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i.toLong) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === LongType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getLong(i) === i.toLong) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getLongs(0, 10) === (0 until 10).map(i => i.toLong)) + + columnVector.close() + allocator.close() + } + + test("float") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true) + .createVector(allocator).asInstanceOf[NullableFloat4Vector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i.toFloat) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === FloatType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getFloat(i) === i.toFloat) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getFloats(0, 10) === (0 until 10).map(i => i.toFloat)) + + columnVector.close() + allocator.close() + } + + test("double") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true) + .createVector(allocator).asInstanceOf[NullableFloat8Vector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i.toDouble) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === DoubleType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getDouble(i) === i.toDouble) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getDoubles(0, 10) === (0 until 10).map(i => i.toDouble)) + + columnVector.close() + allocator.close() + } + + test("string") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("string", StringType, nullable = true) + .createVector(allocator).asInstanceOf[NullableVarCharVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + mutator.setSafe(i, utf8, 0, utf8.length) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === StringType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getUTF8String(i) === UTF8String.fromString(s"str$i")) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + 
allocator.close() + } + + test("binary") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true) + .createVector(allocator).asInstanceOf[NullableVarBinaryVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + mutator.setSafe(i, utf8, 0, utf8.length) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === BinaryType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getBinary(i) === s"str$i".getBytes("utf8")) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("array") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true) + .createVector(allocator).asInstanceOf[ListVector] + vector.allocateNew() + val mutator = vector.getMutator() + val elementVector = vector.getDataVector().asInstanceOf[NullableIntVector] + val elementMutator = elementVector.getMutator() + + // [1, 2] + mutator.startNewValue(0) + elementMutator.setSafe(0, 1) + elementMutator.setSafe(1, 2) + mutator.endValue(0, 2) + + // [3, null, 5] + mutator.startNewValue(1) + elementMutator.setSafe(2, 3) + elementMutator.setNull(3) + elementMutator.setSafe(4, 5) + mutator.endValue(1, 3) + + // null + + // [] + mutator.startNewValue(3) + mutator.endValue(3, 0) + + elementMutator.setValueCount(5) + mutator.setValueCount(4) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === ArrayType(IntegerType)) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + val array0 = columnVector.getArray(0) + assert(array0.numElements() === 2) + assert(array0.getInt(0) === 1) + assert(array0.getInt(1) === 2) + + val array1 = columnVector.getArray(1) + assert(array1.numElements() === 3) + assert(array1.getInt(0) === 3) + assert(array1.isNullAt(1)) + assert(array1.getInt(2) === 5) + + assert(columnVector.isNullAt(2)) + + val array3 = columnVector.getArray(3) + assert(array3.numElements() === 0) + + columnVector.close() + allocator.close() + } + + test("struct") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = true) + .createVector(allocator).asInstanceOf[NullableMapVector] + vector.allocateNew() + val mutator = vector.getMutator() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[NullableIntVector] + val intMutator = intVector.getMutator() + val longVector = vector.getChildByOrdinal(1).asInstanceOf[NullableBigIntVector] + val longMutator = longVector.getMutator() + + // (1, 1L) + mutator.setIndexDefined(0) + intMutator.setSafe(0, 1) + longMutator.setSafe(0, 1L) + + // (2, null) + mutator.setIndexDefined(1) + intMutator.setSafe(1, 2) + longMutator.setNull(1) + + // (null, 3L) + mutator.setIndexDefined(2) + intMutator.setNull(2) + longMutator.setSafe(2, 3L) + + // null + mutator.setNull(3) + + // (5, 5L) + mutator.setIndexDefined(4) + intMutator.setSafe(4, 5) + longMutator.setSafe(4, 5L) + + intMutator.setValueCount(5) + longMutator.setValueCount(5) + mutator.setValueCount(5) + + val 
columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === schema) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + val row0 = columnVector.getStruct(0, 2) + assert(row0.getInt(0) === 1) + assert(row0.getLong(1) === 1L) + + val row1 = columnVector.getStruct(1, 2) + assert(row1.getInt(0) === 2) + assert(row1.isNullAt(1)) + + val row2 = columnVector.getStruct(2, 2) + assert(row2.isNullAt(0)) + assert(row2.getLong(1) === 3L) + + assert(columnVector.isNullAt(3)) + + val row4 = columnVector.getStruct(4, 2) + assert(row4.getInt(0) === 5) + assert(row4.getLong(1) === 5L) + + columnVector.close() + allocator.close() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala new file mode 100644 index 000000000000..85da8270d4cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.vectorized + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { + private def withVector( + vector: WritableColumnVector)( + block: WritableColumnVector => Unit): Unit = { + try block(vector) finally vector.close() + } + + private def testVectors( + name: String, + size: Int, + dt: DataType)( + block: WritableColumnVector => Unit): Unit = { + test(name) { + withVector(new OnHeapColumnVector(size, dt))(block) + withVector(new OffHeapColumnVector(size, dt))(block) + } + } + + testVectors("boolean", 10, BooleanType) { testVector => + (0 until 10).foreach { i => + testVector.appendBoolean(i % 2 == 0) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, BooleanType) === (i % 2 == 0)) + } + } + + testVectors("byte", 10, ByteType) { testVector => + (0 until 10).foreach { i => + testVector.appendByte(i.toByte) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, ByteType) === i.toByte) + } + } + + testVectors("short", 10, ShortType) { testVector => + (0 until 10).foreach { i => + testVector.appendShort(i.toShort) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, ShortType) === i.toShort) + } + } + + testVectors("int", 10, IntegerType) { testVector => + (0 until 10).foreach { i => + testVector.appendInt(i) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, IntegerType) === i) + } + } + + testVectors("long", 10, LongType) { testVector => + (0 until 10).foreach { i => + testVector.appendLong(i) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, LongType) === i) + } + } + + testVectors("float", 10, FloatType) { testVector => + (0 until 10).foreach { i => + testVector.appendFloat(i.toFloat) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, FloatType) === i.toFloat) + } + } + + testVectors("double", 10, DoubleType) { testVector => + (0 until 10).foreach { i => + testVector.appendDouble(i.toDouble) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, DoubleType) === i.toDouble) + } + } + + testVectors("string", 10, StringType) { testVector => + (0 until 10).map { i => + val utf8 = s"str$i".getBytes("utf8") + testVector.appendByteArray(utf8, 0, utf8.length) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, StringType) === UTF8String.fromString(s"str$i")) + } + } + + testVectors("binary", 10, BinaryType) { testVector => + (0 until 10).map { i => + val utf8 = s"str$i".getBytes("utf8") + testVector.appendByteArray(utf8, 0, utf8.length) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + assert(array.get(i, BinaryType) === utf8) + } + } + + val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true) + testVectors("array", 10, arrayType) { testVector => + + val data = testVector.arrayData() + var i = 0 + while (i < 6) { + data.putInt(i, i) + i += 1 + } + + // Populate it with arrays 
[0], [1, 2], [], [3, 4, 5] + testVector.putArray(0, 0, 1) + testVector.putArray(1, 1, 2) + testVector.putArray(2, 3, 0) + testVector.putArray(3, 3, 3) + + val array = new ColumnVector.Array(testVector) + + assert(array.get(0, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(0)) + assert(array.get(1, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(1, 2)) + assert(array.get(2, arrayType).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int]) + assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5)) + } + + val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) + testVectors("struct", 10, structType) { testVector => + val c1 = testVector.getChildColumn(0) + val c2 = testVector.getChildColumn(1) + c1.putInt(0, 123) + c2.putDouble(0, 3.45) + c1.putInt(1, 456) + c2.putDouble(1, 5.67) + + val array = new ColumnVector.Array(testVector) + + assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) + assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) + assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) + assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) + } + + test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { + withVector(new OffHeapColumnVector(8, arrayType)) { testVector => + val data = testVector.arrayData() + (0 until 8).foreach(i => data.putInt(i, i)) + (0 until 8).foreach(i => testVector.putArray(i, i, 1)) + + // Increase vector's capacity and reallocate the data to new bigger buffers. + testVector.reserve(16) + + // Check that none of the values got lost/overwritten. + val array = new ColumnVector.Array(testVector) + (0 until 8).foreach { i => + assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) + } + } + } + + test("[SPARK-22092] off-heap column vector reallocation corrupts struct nullability") { + withVector(new OffHeapColumnVector(8, structType)) { testVector => + (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i)) + testVector.reserve(16) + (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 67b3d98c1dae..1331f157363b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -24,7 +24,10 @@ import scala.util.Random import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.ColumnVector -import org.apache.spark.sql.types.{BinaryType, IntegerType} +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.execution.vectorized.WritableColumnVector +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType} import org.apache.spark.unsafe.Platform import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.BitSet @@ -34,6 +37,14 @@ import org.apache.spark.util.collection.BitSet */ object ColumnarBatchBenchmark { + def 
allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { + if (memMode == MemoryMode.OFF_HEAP) { + new OffHeapColumnVector(capacity, dt) + } else { + new OnHeapColumnVector(capacity, dt) + } + } + // This benchmark reads and writes an array of ints. // TODO: there is a big (2x) penalty for a random access API for off heap. // Note: carefully if modifying this code. It's hard to reason about the JIT. @@ -140,7 +151,7 @@ object ColumnarBatchBenchmark { // Access through the column API with on heap memory val columnOnHeap = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) + val col = allocate(count, IntegerType, MemoryMode.ON_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -159,7 +170,7 @@ object ColumnarBatchBenchmark { // Access through the column API with off heap memory def columnOffHeap = { i: Int => { - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP) + val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -178,7 +189,7 @@ object ColumnarBatchBenchmark { // Access by directly getting the buffer backing the column. val columnOffheapDirect = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP) + val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP) var sum = 0L for (n <- 0L until iters) { var addr = col.valuesNativeAddress() @@ -244,7 +255,7 @@ object ColumnarBatchBenchmark { // Adding values by appending, instead of putting. val onHeapAppend = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) + val col = allocate(count, IntegerType, MemoryMode.ON_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -362,7 +373,7 @@ object ColumnarBatchBenchmark { .map(_.getBytes(StandardCharsets.UTF_8)).toArray def column(memoryMode: MemoryMode) = { i: Int => - val column = ColumnVector.allocate(count, BinaryType, memoryMode) + val column = allocate(count, BinaryType, memoryMode) var sum = 0L for (n <- 0L until iters) { var i = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8184d7d909f4..983eb103682c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -17,48 +17,98 @@ package org.apache.spark.sql.execution.vectorized -import java.nio.charset.StandardCharsets import java.nio.ByteBuffer import java.nio.ByteOrder +import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random +import org.apache.arrow.vector.NullableIntVector + import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { - test("Null Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val reference = mutable.ArrayBuffer.empty[Boolean] - val column = ColumnVector.allocate(1024, IntegerType, memMode) + private def allocate(capacity: Int, 
dt: DataType, memMode: MemoryMode): WritableColumnVector = { + if (memMode == MemoryMode.OFF_HEAP) { + new OffHeapColumnVector(capacity, dt) + } else { + new OnHeapColumnVector(capacity, dt) + } + } + + private def testVector( + name: String, + size: Int, + dt: DataType)( + block: (WritableColumnVector, MemoryMode) => Unit): Unit = { + test(name) { + Seq(MemoryMode.ON_HEAP, MemoryMode.OFF_HEAP).foreach { mode => + val vector = allocate(size, dt, mode) + try block(vector, mode) finally { + vector.close() + } + } + } + } + + testVector("Null APIs", 1024, IntegerType) { + (column, memMode) => + val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 - assert(column.anyNullsSet() == false) + assert(!column.anyNullsSet()) + assert(column.numNulls() == 0) + + column.appendNotNull() + reference += false + assert(!column.anyNullsSet()) + assert(column.numNulls() == 0) + + column.appendNotNulls(3) + (1 to 3).foreach(_ => reference += false) + assert(!column.anyNullsSet()) + assert(column.numNulls() == 0) + + column.appendNull() + reference += true + assert(column.anyNullsSet()) + assert(column.numNulls() == 1) + + column.appendNulls(3) + (1 to 3).foreach(_ => reference += true) + assert(column.anyNullsSet()) + assert(column.numNulls() == 4) + + idx = column.elementsAppended column.putNotNull(idx) reference += false idx += 1 - assert(column.anyNullsSet() == false) + assert(column.anyNullsSet()) + assert(column.numNulls() == 4) column.putNull(idx) reference += true idx += 1 - assert(column.anyNullsSet() == true) - assert(column.numNulls() == 1) + assert(column.anyNullsSet()) + assert(column.numNulls() == 5) column.putNulls(idx, 3) reference += true reference += true reference += true idx += 3 - assert(column.anyNullsSet() == true) + assert(column.anyNullsSet()) + assert(column.numNulls() == 8) column.putNotNulls(idx, 4) reference += false @@ -66,8 +116,8 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += false reference += false idx += 4 - assert(column.anyNullsSet() == true) - assert(column.numNulls() == 4) + assert(column.anyNullsSet()) + assert(column.numNulls() == 8) reference.zipWithIndex.foreach { v => assert(v._1 == column.isNullAt(v._2)) @@ -76,18 +126,31 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) } } - column.close - }} } - test("Byte Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Byte APIs", 1024, ByteType) { + (column, memMode) => val reference = mutable.ArrayBuffer.empty[Byte] - val column = ColumnVector.allocate(1024, ByteType, memMode) - var idx = 0 + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray + column.appendBytes(2, values, 0) + reference += 10.toByte + reference += 20.toByte + + column.appendBytes(3, values, 2) + reference += 30.toByte + reference += 40.toByte + reference += 50.toByte + + column.appendBytes(6, 60.toByte) + (1 to 6).foreach(_ => reference += 60.toByte) + + column.appendByte(70.toByte) + reference += 70.toByte + + var idx = column.elementsAppended - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray column.putBytes(idx, 2, values, 0) reference += 1 reference += 2 @@ -116,19 +179,33 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getByte(null, addr + v._2)) } } - }} } - test("Short Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + 
testVector("Short APIs", 1024, ShortType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Short] - val column = ColumnVector.allocate(1024, ShortType, memMode) - var idx = 0 + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toShort).toArray + column.appendShorts(2, values, 0) + reference += 10.toShort + reference += 20.toShort + + column.appendShorts(3, values, 2) + reference += 30.toShort + reference += 40.toShort + reference += 50.toShort - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toShort).toArray + column.appendShorts(6, 60.toShort) + (1 to 6).foreach(_ => reference += 60.toShort) + + column.appendShort(70.toShort) + reference += 70.toShort + + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toShort).toArray column.putShorts(idx, 2, values, 0) reference += 1 reference += 2 @@ -177,21 +254,33 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getShort(null, addr + 2 * v._2)) } } - - column.close - }} } - test("Int Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Int APIs", 1024, IntegerType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Int] - val column = ColumnVector.allocate(1024, IntegerType, memMode) - var idx = 0 + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).toArray + column.appendInts(2, values, 0) + reference += 10 + reference += 20 + + column.appendInts(3, values, 2) + reference += 30 + reference += 40 + reference += 50 + + column.appendInts(6, 60) + (1 to 6).foreach(_ => reference += 60) + + column.appendInt(70) + reference += 70 - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray column.putInts(idx, 2, values, 0) reference += 1 reference += 2 @@ -246,20 +335,33 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getInt(null, addr + 4 * v._2)) } } - column.close - }} } - test("Long Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Long APIs", 1024, LongType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Long] - val column = ColumnVector.allocate(1024, LongType, memMode) - var idx = 0 + var values = (10L :: 20L :: 30L :: 40L :: 50L :: Nil).toArray + column.appendLongs(2, values, 0) + reference += 10L + reference += 20L + + column.appendLongs(3, values, 2) + reference += 30L + reference += 40L + reference += 50L - val values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray + column.appendLongs(6, 60L) + (1 to 6).foreach(_ => reference += 60L) + + column.appendLong(70L) + reference += 70L + + var idx = column.elementsAppended + + values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray column.putLongs(idx, 2, values, 0) reference += 1 reference += 2 @@ -317,19 +419,120 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) } } - }} } - test("Double APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Float APIs", 1024, FloatType) { + (column, memMode) => + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Float] + + var values = (.1f :: .2f :: .3f :: .4f :: .5f :: Nil).toArray + 
column.appendFloats(2, values, 0) + reference += .1f + reference += .2f + + column.appendFloats(3, values, 2) + reference += .3f + reference += .4f + reference += .5f + + column.appendFloats(6, .6f) + (1 to 6).foreach(_ => reference += .6f) + + column.appendFloat(.7f) + reference += .7f + + var idx = column.elementsAppended + + values = (1.0f :: 2.0f :: 3.0f :: 4.0f :: 5.0f :: Nil).toArray + column.putFloats(idx, 2, values, 0) + reference += 1.0f + reference += 2.0f + idx += 2 + + column.putFloats(idx, 3, values, 2) + reference += 3.0f + reference += 4.0f + reference += 5.0f + idx += 3 + + val buffer = new Array[Byte](8) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET, 2.234f) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET + 4, 1.123f) + + if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { + // Ensure array contains Little Endian floats + val bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET, bb.getFloat(0)) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET + 4, bb.getFloat(4)) + } + + column.putFloats(idx, 1, buffer, 4) + column.putFloats(idx + 1, 1, buffer, 0) + reference += 1.123f + reference += 2.234f + idx += 2 + + column.putFloats(idx, 2, buffer, 0) + reference += 2.234f + reference += 1.123f + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextFloat() + column.putFloat(idx, v) + reference += v + idx += 1 + } else { + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + val v = random.nextFloat() + column.putFloats(idx, n, v) + var i = 0 + while (i < n) { + reference += v + i += 1 + } + idx += n + } + } + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getFloat(v._2), "Seed = " + seed + " MemMode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getFloat(null, addr + 4 * v._2)) + } + } + } + + testVector("Double APIs", 1024, DoubleType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Double] - val column = ColumnVector.allocate(1024, DoubleType, memMode) - var idx = 0 + var values = (.1 :: .2 :: .3 :: .4 :: .5 :: Nil).toArray + column.appendDoubles(2, values, 0) + reference += .1 + reference += .2 + + column.appendDoubles(3, values, 2) + reference += .3 + reference += .4 + reference += .5 + + column.appendDoubles(6, .6) + (1 to 6).foreach(_ => reference += .6) + + column.appendDouble(.7) + reference += .7 - val values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray + var idx = column.elementsAppended + + values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray column.putDoubles(idx, 2, values, 0) reference += 1.0 reference += 2.0 @@ -346,8 +549,8 @@ class ColumnarBatchSuite extends SparkFunSuite { Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, 1.123) if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { - // Ensure array contains Liitle Endian doubles - var bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) + // Ensure array contains Little Endian doubles + val bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET, bb.getDouble(0)) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, bb.getDouble(8)) } @@ -390,50 +593,54 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getDouble(null, addr + 8 * v._2)) } } 
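The Byte/Short/Int/Long/Float/Double cases above all follow the same write-then-verify shape; a minimal sketch of that pattern for `IntegerType`, assuming the `allocate` on-heap/off-heap helper introduced at the top of this suite:

```scala
// Sketch only: the common pattern behind the per-type WritableColumnVector tests.
Seq(MemoryMode.ON_HEAP, MemoryMode.OFF_HEAP).foreach { mode =>
  val column = allocate(1024, IntegerType, mode)
  try {
    (0 until 10).foreach(i => column.putInt(i, i))    // positional writes
    column.putNull(10)                                // mark a slot as null
    (0 until 10).foreach(i => assert(column.getInt(i) == i))
    assert(column.isNullAt(10))
  } finally {
    column.close()                                    // frees off-heap memory
  }
}
```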
- column.close - }} } - test("String APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("String APIs", 6, StringType) { + (column, memMode) => val reference = mutable.ArrayBuffer.empty[String] - val column = ColumnVector.allocate(6, BinaryType, memMode) assert(column.arrayData().elementsAppended == 0) - var idx = 0 + + val str = "string" + column.appendByteArray(str.getBytes(StandardCharsets.UTF_8), + 0, str.getBytes(StandardCharsets.UTF_8).length) + reference += str + assert(column.arrayData().elementsAppended == 6) + + var idx = column.elementsAppended val values = ("Hello" :: "abc" :: Nil).toArray column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), 0, values(0).getBytes(StandardCharsets.UTF_8).length) reference += values(0) idx += 1 - assert(column.arrayData().elementsAppended == 5) + assert(column.arrayData().elementsAppended == 11) column.putByteArray(idx, values(1).getBytes(StandardCharsets.UTF_8), 0, values(1).getBytes(StandardCharsets.UTF_8).length) reference += values(1) idx += 1 - assert(column.arrayData().elementsAppended == 8) + assert(column.arrayData().elementsAppended == 14) // Just put llo val offset = column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), 2, values(0).getBytes(StandardCharsets.UTF_8).length - 2) reference += "llo" idx += 1 - assert(column.arrayData().elementsAppended == 11) + assert(column.arrayData().elementsAppended == 17) // Put the same "ll" at offset. This should not allocate more memory in the column. column.putArray(idx, offset, 2) reference += "ll" idx += 1 - assert(column.arrayData().elementsAppended == 11) + assert(column.arrayData().elementsAppended == 17) // Put a long string val s = "abcdefghijklmnopqrstuvwxyz" column.putByteArray(idx, (s + s).getBytes(StandardCharsets.UTF_8)) reference += (s + s) idx += 1 - assert(column.arrayData().elementsAppended == 11 + (s + s).length) + assert(column.arrayData().elementsAppended == 17 + (s + s).length) reference.zipWithIndex.foreach { v => assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) @@ -443,15 +650,13 @@ class ColumnarBatchSuite extends SparkFunSuite { column.reset() assert(column.arrayData().elementsAppended == 0) - }} } - test("Int Array") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val column = ColumnVector.allocate(10, new ArrayType(IntegerType, true), memMode) + testVector("Int Array", 10, new ArrayType(IntegerType, true)) { + (column, _) => // Fill the underlying data with all the arrays back to back. 
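The array layout used here (and in the ColumnVectorSuite above) keeps all element values back to back in a single child vector, with each row holding only an (offset, length) pair. A condensed sketch of that layout, again assuming the `allocate` helper:

```scala
// Sketch only: array columns = one flat child vector + (offset, length) per row.
val column = allocate(10, new ArrayType(IntegerType, true), MemoryMode.ON_HEAP)
val data = column.arrayData()                 // flat child vector shared by every row
(0 until 6).foreach(i => data.putInt(i, i))   // element values 0..5, back to back
column.putArray(0, 0, 1)                      // row 0 -> [0]
column.putArray(1, 1, 2)                      // row 1 -> [1, 2]
column.putArray(2, 3, 0)                      // row 2 -> []
column.putArray(3, 3, 3)                      // row 3 -> [3, 4, 5]
assert(column.getArray(3).getInt(2) == 5)
column.close()
```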
- val data = column.arrayData(); + val data = column.arrayData() var i = 0 while (i < 6) { data.putInt(i, i) @@ -489,7 +694,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getArray(3).getInt(2) == 5) // Add a longer array which requires resizing - column.reset + column.reset() val array = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) assert(data.capacity == 10) data.reserve(array.length) @@ -498,14 +703,67 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(0, 0, array.length) assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] === array) - }} } - test("Struct Column") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val schema = new StructType().add("int", IntegerType).add("double", DoubleType) - val column = ColumnVector.allocate(1024, schema, memMode) + test("toArray for primitive types") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val len = 4 + + val columnBool = allocate(len, new ArrayType(BooleanType, false), memMode) + val boolArray = Array(false, true, false, true) + boolArray.zipWithIndex.foreach { case (v, i) => columnBool.arrayData.putBoolean(i, v) } + columnBool.putArray(0, 0, len) + assert(columnBool.getArray(0).toBooleanArray === boolArray) + columnBool.close() + + val columnByte = allocate(len, new ArrayType(ByteType, false), memMode) + val byteArray = Array[Byte](0, 1, 2, 3) + byteArray.zipWithIndex.foreach { case (v, i) => columnByte.arrayData.putByte(i, v) } + columnByte.putArray(0, 0, len) + assert(columnByte.getArray(0).toByteArray === byteArray) + columnByte.close() + + val columnShort = allocate(len, new ArrayType(ShortType, false), memMode) + val shortArray = Array[Short](0, 1, 2, 3) + shortArray.zipWithIndex.foreach { case (v, i) => columnShort.arrayData.putShort(i, v) } + columnShort.putArray(0, 0, len) + assert(columnShort.getArray(0).toShortArray === shortArray) + columnShort.close() + + val columnInt = allocate(len, new ArrayType(IntegerType, false), memMode) + val intArray = Array(0, 1, 2, 3) + intArray.zipWithIndex.foreach { case (v, i) => columnInt.arrayData.putInt(i, v) } + columnInt.putArray(0, 0, len) + assert(columnInt.getArray(0).toIntArray === intArray) + columnInt.close() + + val columnLong = allocate(len, new ArrayType(LongType, false), memMode) + val longArray = Array[Long](0, 1, 2, 3) + longArray.zipWithIndex.foreach { case (v, i) => columnLong.arrayData.putLong(i, v) } + columnLong.putArray(0, 0, len) + assert(columnLong.getArray(0).toLongArray === longArray) + columnLong.close() + + val columnFloat = allocate(len, new ArrayType(FloatType, false), memMode) + val floatArray = Array(0.0F, 1.1F, 2.2F, 3.3F) + floatArray.zipWithIndex.foreach { case (v, i) => columnFloat.arrayData.putFloat(i, v) } + columnFloat.putArray(0, 0, len) + assert(columnFloat.getArray(0).toFloatArray === floatArray) + columnFloat.close() + + val columnDouble = allocate(len, new ArrayType(DoubleType, false), memMode) + val doubleArray = Array(0.0, 1.1, 2.2, 3.3) + doubleArray.zipWithIndex.foreach { case (v, i) => columnDouble.arrayData.putDouble(i, v) } + columnDouble.putArray(0, 0, len) + assert(columnDouble.getArray(0).toDoubleArray === doubleArray) + columnDouble.close() + } + } + testVector( + "Struct Column", + 10, + new StructType().add("int", IntegerType).add("double", DoubleType)) { (column, _) => val c1 = column.getChildColumn(0) val c2 = column.getChildColumn(1) assert(c1.dataType() == IntegerType) @@ -528,7 +786,122 @@ class 
ColumnarBatchSuite extends SparkFunSuite { val s2 = column.getStruct(1) assert(s2.getInt(0) == 456) assert(s2.getDouble(1) == 5.67) - }} + } + + testVector("Nest Array in Array", 10, new ArrayType(new ArrayType(IntegerType, true), true)) { + (column, _) => + val childColumn = column.arrayData() + val data = column.arrayData().arrayData() + (0 until 6).foreach { + case 3 => data.putNull(3) + case i => data.putInt(i, i) + } + // Arrays in child column: [0], [1, 2], [], [null, 4, 5] + childColumn.putArray(0, 0, 1) + childColumn.putArray(1, 1, 2) + childColumn.putArray(2, 2, 0) + childColumn.putArray(3, 3, 3) + // Arrays in column: [[0]], [[1, 2], []], [[], [null, 4, 5]], null + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putArray(2, 2, 2) + column.putNull(3) + + assert(column.getArray(0).getArray(0).toIntArray() === Array(0)) + assert(column.getArray(1).getArray(0).toIntArray() === Array(1, 2)) + assert(column.getArray(1).getArray(1).toIntArray() === Array()) + assert(column.getArray(2).getArray(0).toIntArray() === Array()) + assert(column.getArray(2).getArray(1).isNullAt(0)) + assert(column.getArray(2).getArray(1).getInt(1) === 4) + assert(column.getArray(2).getArray(1).getInt(2) === 5) + assert(column.isNullAt(3)) + } + + private val structType: StructType = new StructType().add("i", IntegerType).add("l", LongType) + + testVector( + "Nest Struct in Array", + 10, + new ArrayType(structType, true)) { (column, _) => + val data = column.arrayData() + val c0 = data.getChildColumn(0) + val c1 = data.getChildColumn(1) + // Structs in child column: (0, 0), (1, 10), (2, 20), (3, 30), (4, 40), (5, 50) + (0 until 6).foreach { i => + c0.putInt(i, i) + c1.putLong(i, i * 10) + } + // Arrays in column: [(0, 0), (1, 10)], [(1, 10), (2, 20), (3, 30)], + // [(4, 40), (5, 50)] + column.putArray(0, 0, 2) + column.putArray(1, 1, 3) + column.putArray(2, 4, 2) + + assert(column.getArray(0).getStruct(0, 2).toSeq(structType) === Seq(0, 0)) + assert(column.getArray(0).getStruct(1, 2).toSeq(structType) === Seq(1, 10)) + assert(column.getArray(1).getStruct(0, 2).toSeq(structType) === Seq(1, 10)) + assert(column.getArray(1).getStruct(1, 2).toSeq(structType) === Seq(2, 20)) + assert(column.getArray(1).getStruct(2, 2).toSeq(structType) === Seq(3, 30)) + assert(column.getArray(2).getStruct(0, 2).toSeq(structType) === Seq(4, 40)) + assert(column.getArray(2).getStruct(1, 2).toSeq(structType) === Seq(5, 50)) + } + + testVector( + "Nest Array in Struct", + 10, + new StructType() + .add("int", IntegerType) + .add("array", new ArrayType(IntegerType, true))) { (column, _) => + val c0 = column.getChildColumn(0) + val c1 = column.getChildColumn(1) + c0.putInt(0, 0) + c0.putInt(1, 1) + c0.putInt(2, 2) + val c1Child = c1.arrayData() + (0 until 6).foreach { i => + c1Child.putInt(i, i) + } + // Arrays in c1: [0, 1], [2], [3, 4, 5] + c1.putArray(0, 0, 2) + c1.putArray(1, 2, 1) + c1.putArray(2, 3, 3) + + assert(column.getStruct(0).getInt(0) === 0) + assert(column.getStruct(0).getArray(1).toIntArray() === Array(0, 1)) + assert(column.getStruct(1).getInt(0) === 1) + assert(column.getStruct(1).getArray(1).toIntArray() === Array(2)) + assert(column.getStruct(2).getInt(0) === 2) + assert(column.getStruct(2).getArray(1).toIntArray() === Array(3, 4, 5)) + } + + private val subSchema: StructType = new StructType() + .add("int", IntegerType) + .add("int", IntegerType) + testVector( + "Nest Struct in Struct", + 10, + new StructType().add("int", IntegerType).add("struct", subSchema)) { (column, _) => + val c0 = 
column.getChildColumn(0) + val c1 = column.getChildColumn(1) + c0.putInt(0, 0) + c0.putInt(1, 1) + c0.putInt(2, 2) + val c1c0 = c1.getChildColumn(0) + val c1c1 = c1.getChildColumn(1) + // Structs in c1: (7, 70), (8, 80), (9, 90) + c1c0.putInt(0, 7) + c1c0.putInt(1, 8) + c1c0.putInt(2, 9) + c1c1.putInt(0, 70) + c1c1.putInt(1, 80) + c1c1.putInt(2, 90) + + assert(column.getStruct(0).getInt(0) === 0) + assert(column.getStruct(0).getStruct(1, 2).toSeq(subSchema) === Seq(7, 70)) + assert(column.getStruct(1).getInt(0) === 1) + assert(column.getStruct(1).getStruct(1, 2).toSeq(subSchema) === Seq(8, 80)) + assert(column.getStruct(2).getInt(0) === 2) + assert(column.getStruct(2).getStruct(1, 2).toSeq(subSchema) === Seq(9, 90)) } test("ColumnarBatch basic") { @@ -539,7 +912,11 @@ class ColumnarBatchSuite extends SparkFunSuite { .add("intCol2", IntegerType) .add("string", BinaryType) - val batch = ColumnarBatch.allocate(schema, memMode) + val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE + val columns = schema.fields.map { field => + allocate(capacity, field.dataType, memMode) + } + val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE) assert(batch.numCols() == 4) assert(batch.numRows() == 0) assert(batch.numValidRows() == 0) @@ -547,10 +924,10 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.rowIterator().hasNext == false) // Add a row [1, 1.1, NULL] - batch.column(0).putInt(0, 1) - batch.column(1).putDouble(0, 1.1) - batch.column(2).putNull(0) - batch.column(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8)) + columns(0).putInt(0, 1) + columns(1).putDouble(0, 1.1) + columns(2).putNull(0) + columns(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8)) batch.setNumRows(1) // Verify the results of the row. @@ -560,12 +937,12 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.rowIterator().hasNext == true) assert(batch.rowIterator().hasNext == true) - assert(batch.column(0).getInt(0) == 1) - assert(batch.column(0).isNullAt(0) == false) - assert(batch.column(1).getDouble(0) == 1.1) - assert(batch.column(1).isNullAt(0) == false) - assert(batch.column(2).isNullAt(0) == true) - assert(batch.column(3).getUTF8String(0).toString == "Hello") + assert(columns(0).getInt(0) == 1) + assert(columns(0).isNullAt(0) == false) + assert(columns(1).getDouble(0) == 1.1) + assert(columns(1).isNullAt(0) == false) + assert(columns(2).isNullAt(0) == true) + assert(columns(3).getUTF8String(0).toString == "Hello") // Verify the iterator works correctly. 
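The construction change above (building the batch from explicitly allocated `WritableColumnVector`s rather than `ColumnarBatch.allocate`) can be summarized as follows; this is a sketch only, reusing the `allocate` helper, `schema`, and `memMode` from this test:

```scala
// Sketch only: a ColumnarBatch now wraps externally allocated column vectors.
val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE
val columns = schema.fields.map(f => allocate(capacity, f.dataType, memMode))
val batch = new ColumnarBatch(schema, columns.toArray, capacity)

columns(0).putInt(0, 1)                        // write through the vectors, not the batch
columns(1).putDouble(0, 1.1)
columns(2).putNull(0)
columns(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8))
batch.setNumRows(1)                            // then tell the batch how many rows are valid

assert(batch.rowIterator().next().getInt(0) == 1)
batch.close()
```

Decoupling the batch from vector allocation is what lets the later "create columnar batch from Arrow column vectors" test build a `ColumnarBatch` directly on top of `ArrowColumnVector`s instead of Spark-allocated buffers.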
val it = batch.rowIterator() @@ -576,7 +953,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(row.getDouble(1) == 1.1) assert(row.isNullAt(1) == false) assert(row.isNullAt(2) == true) - assert(batch.column(3).getUTF8String(0).toString == "Hello") + assert(columns(3).getUTF8String(0).toString == "Hello") assert(it.hasNext == false) assert(it.hasNext == false) @@ -593,20 +970,20 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.rowIterator().hasNext == false) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] - batch.column(0).putNull(0) - batch.column(1).putDouble(0, 2.2) - batch.column(2).putInt(0, 2) - batch.column(3).putByteArray(0, "abc".getBytes(StandardCharsets.UTF_8)) - - batch.column(0).putInt(1, 3) - batch.column(1).putNull(1) - batch.column(2).putInt(1, 3) - batch.column(3).putByteArray(1, "".getBytes(StandardCharsets.UTF_8)) - - batch.column(0).putInt(2, 4) - batch.column(1).putDouble(2, 4.4) - batch.column(2).putInt(2, 4) - batch.column(3).putByteArray(2, "world".getBytes(StandardCharsets.UTF_8)) + columns(0).putNull(0) + columns(1).putDouble(0, 2.2) + columns(2).putInt(0, 2) + columns(3).putByteArray(0, "abc".getBytes(StandardCharsets.UTF_8)) + + columns(0).putInt(1, 3) + columns(1).putNull(1) + columns(2).putInt(1, 3) + columns(3).putByteArray(1, "".getBytes(StandardCharsets.UTF_8)) + + columns(0).putInt(2, 4) + columns(1).putDouble(2, 4.4) + columns(2).putInt(2, 4) + columns(3).putByteArray(2, "world".getBytes(StandardCharsets.UTF_8)) batch.setNumRows(3) def rowEquals(x: InternalRow, y: Row): Unit = { @@ -645,7 +1022,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val it4 = batch.rowIterator() rowEquals(it4.next(), Row(null, 2.2, 2, "abc")) - batch.close + batch.close() }} } @@ -853,7 +1230,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("exceeding maximum capacity should throw an error") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val column = ColumnVector.allocate(1, ByteType, memMode) + val column = allocate(1, ByteType, memMode) column.MAX_CAPACITY = 15 column.appendBytes(5, 0.toByte) // Successfully allocate twice the requested capacity @@ -869,4 +1246,51 @@ class ColumnarBatchSuite extends SparkFunSuite { s"vectorized reader")) } } + + test("create columnar batch from Arrow column vectors") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) + val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true) + .createVector(allocator).asInstanceOf[NullableIntVector] + vector1.allocateNew() + val mutator1 = vector1.getMutator() + val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true) + .createVector(allocator).asInstanceOf[NullableIntVector] + vector2.allocateNew() + val mutator2 = vector2.getMutator() + + (0 until 10).foreach { i => + mutator1.setSafe(i, i) + mutator2.setSafe(i + 1, i) + } + mutator1.setNull(10) + mutator1.setValueCount(11) + mutator2.setNull(0) + mutator2.setValueCount(11) + + val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2)) + + val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType))) + val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11) + batch.setNumRows(11) + + assert(batch.numCols() == 2) + assert(batch.numRows() == 11) + + val rowIter = batch.rowIterator().asScala + rowIter.zipWithIndex.foreach { case (row, i) => + if (i == 10) { + assert(row.isNullAt(0)) + } else { + 
assert(row.getInt(0) == i) + } + if (i == 0) { + assert(row.isNullAt(1)) + } else { + assert(row.getInt(1) == i - 1) + } + } + + batch.close() + allocator.close() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala index d826d3f54d92..f65dcdf119c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala @@ -27,7 +27,7 @@ class ReduceAggregatorSuite extends SparkFunSuite { val encoder: ExpressionEncoder[Int] = ExpressionEncoder() val func = (v1: Int, v2: Int) => v1 + v2 val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt) - assert(aggregator.zero == (false, null)) + assert(aggregator.zero == (false, null).asInstanceOf[(Boolean, Int)]) } test("reduce, merge and finish") { @@ -36,22 +36,22 @@ class ReduceAggregatorSuite extends SparkFunSuite { val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt) val firstReduce = aggregator.reduce(aggregator.zero, 1) - assert(firstReduce == (true, 1)) + assert(firstReduce == ((true, 1))) val secondReduce = aggregator.reduce(firstReduce, 2) - assert(secondReduce == (true, 3)) + assert(secondReduce == ((true, 3))) val thirdReduce = aggregator.reduce(secondReduce, 3) - assert(thirdReduce == (true, 6)) + assert(thirdReduce == ((true, 6))) val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce) - assert(mergeWithZero1 == (true, 1)) + assert(mergeWithZero1 == ((true, 1))) val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero) - assert(mergeWithZero2 == (true, 3)) + assert(mergeWithZero2 == ((true, 3))) val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce) - assert(mergeTwoReduced == (true, 4)) + assert(mergeTwoReduced == ((true, 4))) assert(aggregator.finish(firstReduce)== 1) assert(aggregator.finish(secondReduce) == 3) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 8f9c52cb1e03..6acac1a9aa31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel /** @@ -78,7 +79,7 @@ class CatalogSuite val tempFunc = (e: Seq[Expression]) => e.head val funcMeta = CatalogFunction(FunctionIdentifier(name, None), "className", Nil) sessionCatalog.registerFunction( - funcMeta, ignoreIfExists = false, functionBuilder = Some(tempFunc)) + funcMeta, overrideIfExists = false, functionBuilder = Some(tempFunc)) } private def dropFunction(name: String, db: Option[String] = None): Unit = { @@ -366,6 +367,7 @@ class CatalogSuite withUserDefinedFunction("fn1" -> true, s"$db.fn2" -> false) { // Try to find non existing functions. 
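The temporary-function checks in this hunk all follow from one rule: a temporary function lives in the session, not in a database, so a database-qualified lookup of it is expected to fail even while the unqualified lookup succeeds. A rough illustration with the public catalog API (registering via `spark.udf.register` is just a convenient stand-in for the suite's own helpers, and the names are placeholders):

```scala
// Illustrative only; "fn_demo" and "default" are placeholder names.
spark.udf.register("fn_demo", (x: Int) => x + 1)   // creates a temporary function

assert(spark.catalog.functionExists("fn_demo"))              // session-level lookup works
assert(!spark.catalog.functionExists("default", "fn_demo"))  // database-qualified lookup does not

// getFunction behaves consistently: qualifying a temporary function fails.
intercept[AnalysisException] {
  spark.catalog.getFunction("default", "fn_demo")
}
```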
intercept[AnalysisException](spark.catalog.getFunction("fn1")) + intercept[AnalysisException](spark.catalog.getFunction(db, "fn1")) intercept[AnalysisException](spark.catalog.getFunction("fn2")) intercept[AnalysisException](spark.catalog.getFunction(db, "fn2")) @@ -378,6 +380,8 @@ class CatalogSuite assert(fn1.name === "fn1") assert(fn1.database === null) assert(fn1.isTemporary) + // Find a temporary function with database + intercept[AnalysisException](spark.catalog.getFunction(db, "fn1")) // Find a qualified function val fn2 = spark.catalog.getFunction(db, "fn2") @@ -454,6 +458,7 @@ class CatalogSuite // Find a temporary function assert(spark.catalog.functionExists("fn1")) + assert(!spark.catalog.functionExists(db, "fn1")) // Find a qualified function assert(spark.catalog.functionExists(db, "fn2")) @@ -535,4 +540,11 @@ class CatalogSuite .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true) assert(spark.catalog.listTables().collect().map(_.name).toSet == Set()) } + + test("cacheTable with storage level") { + createTempTable("my_temp_table") + spark.catalog.cacheTable("my_temp_table", StorageLevel.DISK_ONLY) + assert(spark.table("my_temp_table").storageLevel == StorageLevel.DISK_ONLY) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala index f2456c770406..135370bd1d67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala @@ -37,6 +37,9 @@ class SQLConfEntrySuite extends SparkFunSuite { assert(conf.getConfString(key) === "20") assert(conf.getConf(confEntry, 5) === 20) + conf.setConfString(key, " 20") + assert(conf.getConf(confEntry, 5) === 20) + val e = intercept[IllegalArgumentException] { conf.setConfString(key, "abc") } @@ -75,6 +78,8 @@ class SQLConfEntrySuite extends SparkFunSuite { assert(conf.getConfString(key) === "true") assert(conf.getConf(confEntry, false) === true) + conf.setConfString(key, " true ") + assert(conf.getConf(confEntry, false) === true) val e = intercept[IllegalArgumentException] { conf.setConfString(key, "abc") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index a283ff971adc..205c303b6cc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.internal @@ -270,4 +270,15 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { val e2 = intercept[AnalysisException](spark.conf.unset(SCHEMA_STRING_LENGTH_THRESHOLD.key)) assert(e2.message.contains("Cannot modify the value of a static config")) } + + test("SPARK-21588 SQLContext.getConf(key, null) should return null") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + assert("1" == spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key, null)) + assert("1" == spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key, "")) + } + + assert(spark.conf.getOption("spark.sql.nonexistent").isEmpty) + assert(null == spark.conf.get("spark.sql.nonexistent", null)) + assert("" == spark.conf.get("spark.sql.nonexistent", "")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5bd36ec25ccb..34205e0b2bf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -18,14 +18,15 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal -import java.sql.{Date, DriverManager, Timestamp} +import java.sql.{Date, DriverManager, SQLException, Timestamp} import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand @@ -96,6 +97,15 @@ class JDBCSuite extends SparkFunSuite | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') """.stripMargin.replaceAll("\n", " ")) + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW partsoverflow + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '-9223372036854775808', + | upperBound '9223372036854775807', numPartitions '3') + """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, " + "d SMALLINT, e BIGINT)").executeUpdate() conn.prepareStatement("insert into test.inttypes values (1, false, 3, 4, 1234567890123)" @@ -141,6 +151,15 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("CREATE TABLE test.timezone (tz TIMESTAMP WITH TIME ZONE) " + + "AS 
SELECT '1999-01-08 04:05:06.543543543 GMT-08:00'") + .executeUpdate() + conn.commit() + + conn.prepareStatement("CREATE TABLE test.array (ar ARRAY) " + + "AS SELECT '(1, 2, 3)'") + .executeUpdate() + conn.commit() conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))" ).executeUpdate() @@ -229,7 +248,7 @@ class JDBCSuite extends SparkFunSuite // Check whether the tables are fetched in the expected degree of parallelism def checkNumPartitions(df: DataFrame, expectedNumPartitions: Int): Unit = { val jdbcRelations = df.queryExecution.analyzed.collect { - case LogicalRelation(r: JDBCRelation, _, _) => r + case LogicalRelation(r: JDBCRelation, _, _, _) => r } assert(jdbcRelations.length == 1) assert(jdbcRelations.head.parts.length == expectedNumPartitions, @@ -367,6 +386,12 @@ class JDBCSuite extends SparkFunSuite assert(ids(2) === 3) } + test("overflow of partition bound difference does not give negative stride") { + val df = sql("SELECT * FROM partsoverflow") + checkNumPartitions(df, expectedNumPartitions = 3) + assert(df.collect().length == 3) + } + test("Register JDBC query with renamed fields") { // Regression test for bug SPARK-7345 sql( @@ -397,6 +422,28 @@ class JDBCSuite extends SparkFunSuite assert(e.contains("Invalid value `-1` for parameter `fetchsize`")) } + test("Missing partition columns") { + withView("tempPeople") { + val e = intercept[IllegalArgumentException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW tempPeople + |USING org.apache.spark.sql.jdbc + |OPTIONS ( + | url 'jdbc:h2:mem:testdb0;user=testUser;password=testPass', + | dbtable 'TEST.PEOPLE', + | lowerBound '0', + | upperBound '52', + | numPartitions '53', + | fetchSize '10000' ) + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e.contains("When reading JDBC data sources, users need to specify all or none " + + "for the following options: 'partitionColumn', 'lowerBound', 'upperBound', and " + + "'numPartitions'")) + } + } + test("Basic API with FetchSize") { (0 to 4).foreach { size => val properties = new Properties() @@ -693,17 +740,53 @@ class JDBCSuite extends SparkFunSuite } else { None } + override def isCascadingTruncateTable(): Option[Boolean] = Some(true) }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) assert(!agg.canHandle("jdbc:h2")) assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) + assert(agg.isCascadingTruncateTable() === Some(true)) + } + + test("Aggregated dialects: isCascadingTruncateTable") { + def genDialect(cascadingTruncateTable: Option[Boolean]): JdbcDialect = new JdbcDialect { + override def canHandle(url: String): Boolean = true + override def getCatalystType( + sqlType: Int, + typeName: String, + size: Int, + md: MetadataBuilder): Option[DataType] = None + override def isCascadingTruncateTable(): Option[Boolean] = cascadingTruncateTable + } + + def testDialects(cascadings: List[Option[Boolean]], expected: Option[Boolean]): Unit = { + val dialects = cascadings.map(genDialect(_)) + val agg = new AggregatedDialect(dialects) + assert(agg.isCascadingTruncateTable() === expected) + } + + testDialects(List(Some(true), Some(false), None), Some(true)) + testDialects(List(Some(true), Some(true), None), Some(true)) + testDialects(List(Some(false), Some(false), None), None) + testDialects(List(Some(true), Some(true)), Some(true)) + testDialects(List(Some(false), Some(false)), Some(false)) + testDialects(List(None, None), None) } test("DB2Dialect type mapping") { val 
db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db") assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") + assert(db2Dialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get == "SMALLINT") + assert(db2Dialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT") + // test db2 dialect mappings on read + assert(db2Dialect.getCatalystType(java.sql.Types.REAL, "REAL", 1, null) == Option(FloatType)) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "DECFLOAT", 1, null) == + Option(DecimalType(38, 18))) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "XML", 1, null) == Option(StringType)) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "TIMESTAMP WITH TIME ZONE", 1, null) == + Option(TimestampType)) } test("PostgresDialect type mapping") { @@ -913,12 +996,63 @@ class JDBCSuite extends SparkFunSuite assert(e2.contains("User specified schema not supported with `jdbc`")) } + test("jdbc API support custom schema") { + val parts = Array[String]("THEID < 2", "THEID >= 2") + val customSchema = "NAME STRING, THEID INT" + val props = new Properties() + props.put("customSchema", customSchema) + val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, props) + assert(df.schema.size === 2) + assert(df.schema === CatalystSqlParser.parseTableSchema(customSchema)) + assert(df.count() === 3) + } + + test("jdbc API custom schema DDL-like strings.") { + withTempView("people_view") { + val customSchema = "NAME STRING, THEID INT" + sql( + s""" + |CREATE TEMPORARY VIEW people_view + |USING org.apache.spark.sql.jdbc + |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass', + |customSchema '$customSchema') + """.stripMargin.replaceAll("\n", " ")) + val df = sql("select * from people_view") + assert(df.schema.length === 2) + assert(df.schema === CatalystSqlParser.parseTableSchema(customSchema)) + assert(df.count() === 3) + } + } + + test("SPARK-15648: teradataDialect StringType data mapping") { + val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + assert(teradataDialect.getJDBCType(StringType). + map(_.databaseTypeDefinition).get == "VARCHAR(255)") + } + + test("SPARK-15648: teradataDialect BooleanType data mapping") { + val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + assert(teradataDialect.getJDBCType(BooleanType). 
+ map(_.databaseTypeDefinition).get == "CHAR(1)") + } + test("Checking metrics correctness with JDBC") { val foobarCnt = spark.table("foobar").count() val res = InputOutputMetricsHelper.run(sql("SELECT * FROM foobar").toDF()) assert(res === (foobarCnt, 0L, foobarCnt) :: Nil) } + test("unsupported types") { + var e = intercept[SparkException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.TIMEZONE", new Properties()).collect() + }.getMessage + assert(e.contains("java.lang.UnsupportedOperationException: unimplemented")) + e = intercept[SQLException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY", new Properties()).collect() + }.getMessage + assert(e.contains("Unsupported type ARRAY")) + } + test("SPARK-19318: Connection properties keys should be case-sensitive.") { def testJdbcOptions(options: JDBCOptions): Unit = { // Spark JDBC data source options are case-insensitive @@ -966,4 +1100,35 @@ class JDBCSuite extends SparkFunSuite assert(sql("select * from people_view").count() == 3) } } + + test("SPARK-21519: option sessionInitStatement, run SQL to initialize the database session.") { + val initSQL1 = "SET @MYTESTVAR 21519" + val df1 = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "(SELECT NVL(@MYTESTVAR, -1))") + .option("sessionInitStatement", initSQL1) + .load() + assert(df1.collect() === Array(Row(21519))) + + val initSQL2 = "SET SCHEMA DUMMY" + val df2 = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PEOPLE") + .option("sessionInitStatement", initSQL2) + .load() + val e = intercept[SparkException] {df2.collect()}.getMessage + assert(e.contains("""Schema "DUMMY" not found""")) + + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW test_sessionInitStatement + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$urlWithUserAndPass', + |dbtable '(SELECT NVL(@MYTESTVAR1, -1), NVL(@MYTESTVAR2, -1))', + |sessionInitStatement 'SET @MYTESTVAR1 21519; SET @MYTESTVAR2 1234') + """.stripMargin) + + val df3 = sql("SELECT * FROM test_sessionInitStatement") + assert(df3.collect() === Array(Row(21519, 1234))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index bf1fd160704f..1985b1dc8287 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.jdbc -import java.sql.{Date, DriverManager, Timestamp} +import java.sql.DriverManager import java.util.Properties import scala.collection.JavaConverters.propertiesAsScalaMapConverter import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} @@ -323,8 +324,9 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { .option("partitionColumn", "foo") .save() }.getMessage - assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," + - " and 'numPartitions' are required.")) + assert(e.contains("When reading JDBC data sources, users need to specify all or none " + + "for the following options: 'partitionColumn', 'lowerBound', 'upperBound', and " + + "'numPartitions'")) } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { @@ 
-466,7 +468,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { .option("createTableColumnTypes", "`name char(20)") // incorrectly quoted column .jdbc(url1, "TEST.USERDBTYPETEST", properties) }.getMessage() - assert(msg.contains("no viable alternative at input")) + assert(msg.contains("extraneous input")) } test("SPARK-10849: jdbc CreateTableColumnTypes duplicate columns") { @@ -478,7 +480,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { .jdbc(url1, "TEST.USERDBTYPETEST", properties) }.getMessage() assert(msg.contains( - "Found duplicate column(s) in createTableColumnTypes option value: name, NaMe")) + "Found duplicate column(s) in the createTableColumnTypes option value: `name`")) } } @@ -506,4 +508,11 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { "schema struct")) } } + + test("SPARK-19726: INSERT null to a NOT NULL column") { + val e = intercept[SparkException] { + sql("INSERT INTO PEOPLE1 values (null, null)") + }.getMessage + assert(e.contains("NULL not allowed for column \"NAME\"")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index ba0ca666b5c1..ab18905e2ddb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -302,10 +302,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { // check existence of shuffle assert( - joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft, + joinOperator.left.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined == shuffleLeft, s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}") assert( - joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight, + joinOperator.right.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined == shuffleRight, s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}") // check existence of sort @@ -506,7 +506,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty) } } @@ -520,7 +520,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty) } } @@ -543,6 +543,65 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { ) } + test("SPARK-19122 Re-order join predicates if they match with 
the child's output partitioning") { + val bucketedTableTestSpec = BucketedTableTestSpec( + Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))), + numPartitions = 1, + expectedShuffle = false, + expectedSort = false) + + // If the set of join columns is equal to the set of bucketed + sort columns, then + // the order of join keys in the query should not matter and there should not be any shuffle + // and sort added in the query plan + Seq( + Seq("i", "j", "k"), + Seq("i", "k", "j"), + Seq("j", "k", "i"), + Seq("j", "i", "k"), + Seq("k", "j", "i"), + Seq("k", "i", "j") + ).foreach(joinKeys => { + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec, + bucketedTableTestSpecRight = bucketedTableTestSpec, + joinCondition = joinCondition(joinKeys) + ) + }) + } + + test("SPARK-19122 No re-ordering should happen if set of join columns != set of child's " + + "partitioning columns") { + + // join predicates is a super set of child's partitioning columns + val bucketedTableTestSpec1 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec1, + bucketedTableTestSpecRight = bucketedTableTestSpec1, + joinCondition = joinCondition(Seq("i", "j", "k")) + ) + + // child's partitioning columns is a super set of join predicates + val bucketedTableTestSpec2 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))), + numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec2, + bucketedTableTestSpecRight = bucketedTableTestSpec2, + joinCondition = joinCondition(Seq("i", "j")) + ) + + // set of child's partitioning columns != set join predicates (despite the lengths of the + // sets are same) + val bucketedTableTestSpec3 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec3, + bucketedTableTestSpecRight = bucketedTableTestSpec3, + joinCondition = joinCondition(Seq("j", "k")) + ) + } + test("error if there exists any malformed bucket files") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 85ba33e58a78..3ce6ae3c5292 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -1,44 +1,57 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.sources import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.types._ // please note that the META-INF/services had to be modified for the test directory for this to work class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { - test("data sources with the same name") { - intercept[RuntimeException] { + test("data sources with the same name - internal data sources") { + val e = intercept[AnalysisException] { spark.read.format("Fluet da Bomb").load() } + assert(e.getMessage.contains("Multiple sources found for Fluet da Bomb")) + } + + test("data sources with the same name - internal data source/external data source") { + assert(spark.read.format("datasource").load().schema == + StructType(Seq(StructField("longType", LongType, nullable = false)))) + } + + test("data sources with the same name - external data sources") { + val e = intercept[AnalysisException] { + spark.read.format("Fake external source").load() + } + assert(e.getMessage.contains("Multiple sources found for Fake external source")) } test("load data source from format alias") { - spark.read.format("gathering quorum").load().schema == - StructType(Seq(StructField("stringType", StringType, nullable = false))) + assert(spark.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false)))) } test("specify full classname with duplicate formats") { - spark.read.format("org.apache.spark.sql.sources.FakeSourceOne") - .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))) } test("should fail to load ORC without Hive Support") { @@ -63,7 +76,7 @@ class FakeSourceOne extends RelationProvider with DataSourceRegister { } } -class FakeSourceTwo extends RelationProvider with DataSourceRegister { +class FakeSourceTwo extends RelationProvider with DataSourceRegister { def shortName(): String = "Fluet da Bomb" @@ -72,7 +85,7 @@ class FakeSourceTwo extends RelationProvider with DataSourceRegister { override def sqlContext: SQLContext = cont override def schema: StructType = - StructType(Seq(StructField("stringType", StringType, nullable = false))) + StructType(Seq(StructField("integerType", IntegerType, nullable = false))) } } @@ -88,3 +101,16 @@ class FakeSourceThree extends RelationProvider with DataSourceRegister { StructType(Seq(StructField("stringType", StringType, nullable = false))) } } 
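FakeSourceFour below, like the other fake providers in this suite and the external ones added in `fakeExternalSources.scala` later in this diff, is only reachable through `spark.read.format(...)` because `DataSourceRegister` implementations are discovered via Java's `ServiceLoader`; that is what the note at the top of the file about modifying `META-INF/services` refers to. A sketch of the registration file on the test classpath (the exact contents are an assumption inferred from the class names in this diff):

```
# META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
org.apache.spark.sql.sources.FakeSourceOne
org.apache.spark.sql.sources.FakeSourceTwo
org.apache.spark.sql.sources.FakeSourceThree
org.apache.spark.sql.sources.FakeSourceFour
org.apache.fakesource.FakeExternalSourceOne
org.apache.fakesource.FakeExternalSourceTwo
org.apache.fakesource.FakeExternalSourceThree
```

With that registration in place, `spark.read.format("Fluet da Bomb")` resolves by short name, and the duplicated aliases in these tests are what now trigger the "Multiple sources found" `AnalysisException` instead of a plain `RuntimeException`.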
+ +class FakeSourceFour extends RelationProvider with DataSourceRegister { + + def shortName(): String = "datasource" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("longType", LongType, nullable = false))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 80868fff897f..1ece98aa7eb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -1,33 +1,36 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String private[sql] abstract class DataSourceTest extends QueryTest { - protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) { + protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row], enableRegex: Boolean = false) { test(sqlString) { - checkAnswer(spark.sql(sqlString), expectedAnswer) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> enableRegex.toString) { + checkAnswer(spark.sql(sqlString), expectedAnswer) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 5a0388ec1d1d..c45b507d2b48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.sql.sources @@ -327,7 +327,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic val table = spark.table("oneToTenFiltered") val relation = table.queryExecution.logical.collectFirst { - case LogicalRelation(r, _, _) => r + case LogicalRelation(r, _, _, _) => r }.get assert( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 2eae66dda88d..875b74551add 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import java.io.File +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -345,4 +346,84 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } } + + test("SPARK-21203 wrong results of insertion of Array of Struct") { + val tabName = "tab1" + withTable(tabName) { + spark.sql( + """ + |CREATE TABLE `tab1` + |(`custom_fields` ARRAY>) + |USING parquet + """.stripMargin) + spark.sql( + """ + |INSERT INTO `tab1` + |SELECT ARRAY(named_struct('id', 1, 'value', 'a'), named_struct('id', 2, 'value', 'b')) + """.stripMargin) + + checkAnswer( + spark.sql("SELECT custom_fields.id, custom_fields.value FROM tab1"), + Row(Array(1, 2), Array("a", "b"))) + } + } + + test("insert overwrite directory") { + withTempDir { dir => + val path = dir.toURI.getPath + + val v1 = + s""" + | INSERT OVERWRITE DIRECTORY '${path}' + | USING json + | OPTIONS (a 1, b 0.1, c TRUE) + | SELECT 1 as a, 'c' as b + """.stripMargin + + spark.sql(v1) + + checkAnswer( + spark.read.json(dir.getCanonicalPath), + sql("SELECT 1 as a, 'c' as b")) + } + } + + test("insert overwrite directory with path in options") { + withTempDir { dir => + val path = dir.toURI.getPath + + val v1 = + s""" + | INSERT OVERWRITE DIRECTORY + | USING json + | OPTIONS ('path' '${path}') + | SELECT 1 as a, 'c' as b + """.stripMargin + + spark.sql(v1) + + checkAnswer( + spark.read.json(dir.getCanonicalPath), + sql("SELECT 1 as a, 'c' as b")) + } + } + + test("insert overwrite directory to data source not providing FileFormat") { + withTempDir { dir => + val path = dir.toURI.getPath + + val v1 = + s""" + | INSERT OVERWRITE DIRECTORY '${path}' + | USING JDBC + | OPTIONS (a 1, b 0.1, c TRUE) + | SELECT 1 as a, 'c' as b + """.stripMargin + val e = intercept[SparkException] { + spark.sql(v1) + }.getMessage + + assert(e.contains("Only Data Sources providing FileFormat are supported")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index a2f3afe3ce23..0fe33e87318a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -32,18 +32,13 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -private class OnlyDetectCustomPathFileCommitProtocol(jobId: String, path: String, isAppend: Boolean) - extends SQLHadoopMapReduceCommitProtocol(jobId, path, isAppend) +private class OnlyDetectCustomPathFileCommitProtocol(jobId: String, path: String) + extends 
SQLHadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging { override def newTaskTempFileAbsPath( taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { - if (isAppend) { - throw new Exception("append data to an existed partitioned table, " + - "there should be no custom partition path sent to Task") - } - - super.newTaskTempFileAbsPath(taskContext, absoluteDir, ext) + throw new Exception("there should be no custom partition path") } } @@ -91,15 +86,15 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempDir { f => spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) } } @@ -111,11 +106,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { .option("maxRecordsPerFile", 1) .mode("overwrite") .parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) } } - test("append data to an existed partitioned table without custom partition path") { + test("append data to an existing partitioned table without custom partition path") { withTable("t") { withSQLConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key -> classOf[OnlyDetectCustomPathFileCommitProtocol].getName) { @@ -138,14 +133,14 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val df = Seq((1, ts)).toDF("i", "ts") withTempPath { f => df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .partitionBy("ts").parquet(f.getAbsolutePath) - val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. checkPartitionValues(files.head, "2016-12-01 08:00:00") @@ -153,18 +148,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // if there isn't timeZone option, then use session local timezone. 
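The timestamp-partition checks in this hunk come down to which timezone formats the partition directory name. A compressed sketch of the three cases being asserted (output paths are placeholders, the GMT-8 default session timezone is an assumption inferred from the expected values, and `spark.implicits._` is assumed for `toDF`):

```scala
// Sketch only; "/tmp/ts-demo" is a placeholder path.
val ts = java.sql.Timestamp.valueOf("2016-12-01 00:00:00")
val df = Seq((1, ts)).toDF("i", "ts")

// 1) No option: the partition value uses the session-local timezone,
//    giving a directory like ts=2016-12-01 00:00:00 (escaped in the actual path).
df.write.partitionBy("ts").parquet("/tmp/ts-demo/default")

// 2) Explicit writer option (DateTimeUtils.TIMEZONE_OPTION is the "timeZone" key):
//    the same instant is rendered in GMT, i.e. 2016-12-01 08:00:00 for data
//    created in a GMT-8 session.
df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT")
  .partitionBy("ts").parquet("/tmp/ts-demo/gmt")

// 3) Setting spark.sql.session.timeZone to GMT (with no writer option) behaves
//    like (2), which is what the checkPartitionValues call right after this
//    note asserts.
```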
checkPartitionValues(files.head, "2016-12-01 08:00:00") } } } - - /** Lists files recursively. */ - private def recursiveList(f: File): Array[File] = { - require(f.isDirectory) - val current = f.listFiles - current ++ current.filter(_.isDirectory).flatMap(recursiveList) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala index 6dd4847ead73..85da3f0e3846 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.sources @@ -92,12 +92,12 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { s""" |CREATE TABLE src |USING ${classOf[TestOptionsSource].getCanonicalName} - |OPTIONS (PATH '$p') + |OPTIONS (PATH '${p.toURI}') |AS SELECT 1 """.stripMargin) assert( spark.table("src").schema.head.metadata.getString("path") == - p.getAbsolutePath) + p.toURI.toString) } } @@ -135,7 +135,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { private def getPathOption(tableName: String): Option[String] = { spark.table(tableName).queryExecution.analyzed.collect { - case LogicalRelation(r: TestOptionsRelation, _, _) => r.pathOption + case LogicalRelation(r: TestOptionsRelation, _, _, _) => r.pathOption }.head } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index fb6123d1cc4b..c1eaf948a4b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. 
See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.sources diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 0f97fd78d2ff..4adbff5c663b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.sql.sources @@ -21,11 +21,12 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.test.SharedSQLContext -class ResolvedDataSourceSuite extends SparkFunSuite { +class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext { private def getProvidingClass(name: String): Class[_] = DataSource( - sparkSession = null, + sparkSession = spark, className = name, options = Map(DateTimeUtils.TIMEZONE_OPTION -> DateTimeUtils.defaultTimeZone().getID) ).providingClass diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index b01d15eb917e..17690e3df915 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.sql.sources @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -247,32 +248,34 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { assert(expectedSchema == spark.table("tableWithSchema").schema) - checkAnswer( - sql( - """SELECT - | `string$%Field`, - | cast(binaryField as string), - | booleanField, - | byteField, - | shortField, - | int_Field, - | `longField_:,<>=+/~^`, - | floatField, - | doubleField, - | decimalField1, - | decimalField2, - | dateField, - | timestampField, - | varcharField, - | charField, - | arrayFieldSimple, - | arrayFieldComplex, - | mapFieldSimple, - | mapFieldComplex, - | structFieldSimple, - | structFieldComplex FROM tableWithSchema""".stripMargin), - tableWithSchemaExpected - ) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + checkAnswer( + sql( + """SELECT + | `string$%Field`, + | cast(binaryField as string), + | booleanField, + | byteField, + | shortField, + | int_Field, + | `longField_:,<>=+/~^`, + | floatField, + | doubleField, + | decimalField1, + | decimalField2, + | dateField, + | timestampField, + | varcharField, + | charField, + | arrayFieldSimple, + | arrayFieldComplex, + | mapFieldSimple, + | mapFieldComplex, + | structFieldSimple, + | structFieldComplex FROM tableWithSchema""".stripMargin), + tableWithSchemaExpected + ) + } } sqlTest( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala new file mode 100644 index 000000000000..bf43de597a7a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fakesource + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} +import org.apache.spark.sql.types._ + + +// Note that the package name is intendedly mismatched in order to resemble external data sources +// and test the detection for them. 
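As the comment above notes, these providers deliberately sit outside the `org.apache.spark` namespace so that the resolution logic treats them as external sources. A short usage sketch of the behavior DDLSourceLoadSuite checks against them (`intercept` comes from ScalaTest, as in the suites above):

```scala
// "datasource" is registered both by the internal FakeSourceFour and by the
// external FakeExternalSourceThree; resolution prefers the internal provider.
val schema = spark.read.format("datasource").load().schema
assert(schema == StructType(Seq(StructField("longType", LongType, nullable = false))))

// Two purely external providers sharing "Fake external source" remain ambiguous.
val e = intercept[AnalysisException] {
  spark.read.format("Fake external source").load()
}
assert(e.getMessage.contains("Multiple sources found for Fake external source"))
```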
+class FakeExternalSourceOne extends RelationProvider with DataSourceRegister { + + def shortName(): String = "Fake external source" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeExternalSourceTwo extends RelationProvider with DataSourceRegister { + + def shortName(): String = "Fake external source" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("integerType", IntegerType, nullable = false))) + } +} + +class FakeExternalSourceThree extends RelationProvider with DataSourceRegister { + + def shortName(): String = "datasource" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("byteType", ByteType, nullable = false))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala new file mode 100644 index 000000000000..933f4075bcc8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite + +/** + * A simple test suite to verify `DataSourceV2Options`. + */ +class DataSourceV2OptionsSuite extends SparkFunSuite { + + test("key is case-insensitive") { + val options = new DataSourceV2Options(Map("foo" -> "bar").asJava) + assert(options.get("foo").get() == "bar") + assert(options.get("FoO").get() == "bar") + assert(!options.get("abc").isPresent) + } + + test("value is case-sensitive") { + val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava) + assert(options.get("foo").get == "bAr") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala new file mode 100644 index 000000000000..9ce93d7ae926 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import java.util.{ArrayList, List => JList} + +import test.org.apache.spark.sql.sources.v2._ + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.{Filter, GreaterThan} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class DataSourceV2Suite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("simplest implementation") { + Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) + } + } + } + + test("advanced implementation") { + Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i))) + checkAnswer(df.select('i).filter('i > 10), Nil) + } + } + } + + test("unsafe row implementation") { + Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) + } + } + } + + test("schema required data source") { + Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => + withClue(cls.getName) { + val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) + assert(e.message.contains("A schema needs to be specified")) + + val schema = new StructType().add("i", "int").add("s", "string") + val df = spark.read.format(cls.getName).schema(schema).load() + + assert(df.schema == schema) + assert(df.collect().isEmpty) + } + } + } +} + +class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createReadTasks(): JList[ReadTask[Row]] = { + java.util.Arrays.asList(new SimpleReadTask(0, 5), new SimpleReadTask(5, 10)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class 
SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] { + private var current = start - 1 + + override def createReader(): DataReader[Row] = new SimpleReadTask(start, end) + + override def next(): Boolean = { + current += 1 + current < end + } + + override def get(): Row = Row(current, -current) + + override def close(): Unit = {} +} + + + +class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { + + var requiredSchema = new StructType().add("i", "int").add("j", "int") + var filters = Array.empty[Filter] + + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + this.filters = filters + Array.empty + } + + override def readSchema(): StructType = { + requiredSchema + } + + override def createReadTasks(): JList[ReadTask[Row]] = { + val lowerBound = filters.collect { + case GreaterThan("i", v: Int) => v + }.headOption + + val res = new ArrayList[ReadTask[Row]] + + if (lowerBound.isEmpty) { + res.add(new AdvancedReadTask(0, 5, requiredSchema)) + res.add(new AdvancedReadTask(5, 10, requiredSchema)) + } else if (lowerBound.get < 4) { + res.add(new AdvancedReadTask(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedReadTask(5, 10, requiredSchema)) + } else if (lowerBound.get < 9) { + res.add(new AdvancedReadTask(lowerBound.get + 1, 10, requiredSchema)) + } + + res + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType) + extends ReadTask[Row] with DataReader[Row] { + + private var current = start - 1 + + override def createReader(): DataReader[Row] = new AdvancedReadTask(start, end, requiredSchema) + + override def close(): Unit = {} + + override def next(): Boolean = { + current += 1 + current < end + } + + override def get(): Row = { + val values = requiredSchema.map(_.name).map { + case "i" => current + case "j" => -current + } + Row.fromSeq(values) + } +} + + +class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createUnsafeRowReadTasks(): JList[ReadTask[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowReadTask(0, 5), new UnsafeRowReadTask(5, 10)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class UnsafeRowReadTask(start: Int, end: Int) + extends ReadTask[UnsafeRow] with DataReader[UnsafeRow] { + + private val row = new UnsafeRow(2) + row.pointTo(new Array[Byte](8 * 3), 8 * 3) + + private var current = start - 1 + + override def createReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + + override def next(): Boolean = { + current += 1 + current < end + } + override def get(): UnsafeRow = { + row.setInt(0, current) + row.setInt(1, -current) + row + } + + override def close(): Unit = {} +} + +class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { + + class Reader(val readSchema: StructType) extends DataSourceV2Reader { + override def createReadTasks(): JList[ReadTask[Row]] = + java.util.Collections.emptyList() + } + + override def createReader(schema: StructType, options: 
DataSourceV2Options): DataSourceV2Reader = + new Reader(schema) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index a15c2cff930f..e858b7d9998a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -268,4 +268,17 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { CheckLastBatch(7) ) } + + test("SPARK-21546: dropDuplicates should ignore watermark when it's not a key") { + val input = MemoryStream[(Int, Int)] + val df = input.toDS.toDF("id", "time") + .withColumn("time", $"time".cast("timestamp")) + .withWatermark("time", "1 second") + .dropDuplicates("id") + .select($"id", $"time".cast("long")) + testStream(df)( + AddData(input, 1 -> 1, 1 -> 2, 2 -> 2), + CheckLastBatch(1 -> 1, 2 -> 2) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala new file mode 100644 index 000000000000..ed9823fbddfd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.util.UUID + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} +import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} +import org.apache.spark.sql.test.SharedSQLContext + +class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext { + + import testImplicits._ + + private var baseDf: DataFrame = null + + override def beforeAll(): Unit = { + super.beforeAll() + baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char") + } + + test("ClusteredDistribution generates Exchange with HashPartitioning") { + testEnsureStatefulOpPartitioning( + baseDf.queryExecution.sparkPlan, + requiredDistribution = keys => ClusteredDistribution(keys), + expectedPartitioning = + keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), + expectShuffle = true) + } + + test("ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning") { + testEnsureStatefulOpPartitioning( + baseDf.coalesce(1).queryExecution.sparkPlan, + requiredDistribution = keys => ClusteredDistribution(keys), + expectedPartitioning = + keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), + expectShuffle = true) + } + + test("AllTuples generates Exchange with SinglePartition") { + testEnsureStatefulOpPartitioning( + baseDf.queryExecution.sparkPlan, + requiredDistribution = _ => AllTuples, + expectedPartitioning = _ => SinglePartition, + expectShuffle = true) + } + + test("AllTuples with coalesce(1) doesn't need Exchange") { + testEnsureStatefulOpPartitioning( + baseDf.coalesce(1).queryExecution.sparkPlan, + requiredDistribution = _ => AllTuples, + expectedPartitioning = _ => SinglePartition, + expectShuffle = false) + } + + /** + * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan + * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to + * ensure the expected partitioning. + */ + private def testEnsureStatefulOpPartitioning( + inputPlan: SparkPlan, + requiredDistribution: Seq[Attribute] => Distribution, + expectedPartitioning: Seq[Attribute] => Partitioning, + expectShuffle: Boolean): Unit = { + val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1))) + val executed = executePlan(operator, OutputMode.Complete()) + if (expectShuffle) { + val exchange = executed.children.find(_.isInstanceOf[Exchange]) + if (exchange.isEmpty) { + fail(s"Was expecting an exchange but didn't get one in:\n$executed") + } + assert(exchange.get === + ShuffleExchangeExec(expectedPartitioning(inputPlan.output.take(1)), inputPlan), + s"Exchange didn't have expected properties:\n${exchange.get}") + } else { + assert(!executed.children.exists(_.isInstanceOf[Exchange]), + s"Unexpected exchange found in:\n$executed") + } + } + + /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. 
*/ + private def executePlan( + p: SparkPlan, + outputMode: OutputMode = OutputMode.Append()): SparkPlan = { + val execution = new IncrementalExecution( + spark, + null, + OutputMode.Complete(), + "chk", + UUID.randomUUID(), + 0L, + OffsetSeqMetadata()) { + override lazy val sparkPlan: SparkPlan = p transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + } + execution.executedPlan + } +} + +/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */ +case class TestStatefulOperator( + child: SparkPlan, + requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { + override def output: Seq[Attribute] = child.output + override def doExecute(): RDD[InternalRow] = child.execute() + override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil + override def stateInfo: Option[StatefulOperatorStateInfo] = None +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index fd850a7365e2..f3e8cf950a5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -21,7 +21,7 @@ import java.{util => ju} import java.text.SimpleDateFormat import java.util.Date -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.OutputMode._ -class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Logging { +class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { import testImplicits._ @@ -38,6 +38,43 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin sqlContext.streams.active.foreach(_.stop()) } + test("EventTimeStats") { + val epsilon = 10E-6 + + val stats = EventTimeStats(max = 100, min = 10, avg = 20.0, count = 5) + stats.add(80L) + stats.max should be (100) + stats.min should be (10) + stats.avg should be (30.0 +- epsilon) + stats.count should be (6) + + val stats2 = EventTimeStats(80L, 5L, 15.0, 4) + stats.merge(stats2) + stats.max should be (100) + stats.min should be (5) + stats.avg should be (24.0 +- epsilon) + stats.count should be (10) + } + + test("EventTimeStats: avg on large values") { + val epsilon = 10E-6 + val largeValue = 10000000000L // 10B + // Make sure `largeValue` will cause overflow if we use a Long sum to calc avg. 
+ assert(largeValue * largeValue != BigInt(largeValue) * BigInt(largeValue)) + val stats = + EventTimeStats(max = largeValue, min = largeValue, avg = largeValue, count = largeValue - 1) + stats.add(largeValue) + stats.avg should be (largeValue.toDouble +- epsilon) + + val stats2 = EventTimeStats( + max = largeValue + 1, + min = largeValue, + avg = largeValue + 1, + count = largeValue) + stats.merge(stats2) + stats.avg should be ((largeValue + 0.5) +- epsilon) + } + test("error on bad column") { val inputData = MemoryStream[Int].toDF() val e = intercept[AnalysisException] { @@ -263,6 +300,84 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin ) } + test("watermark with 2 streams") { + import org.apache.spark.sql.functions.sum + val first = MemoryStream[Int] + + val firstDf = first.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .select('value) + + val second = MemoryStream[Int] + + val secondDf = second.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "5 seconds") + .select('value) + + withTempDir { checkpointDir => + val unionWriter = firstDf.union(secondDf).agg(sum('value)) + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("memory") + .outputMode("complete") + .queryName("test") + + val union = unionWriter.start() + + def getWatermarkAfterData( + firstData: Seq[Int] = Seq.empty, + secondData: Seq[Int] = Seq.empty, + query: StreamingQuery = union): Long = { + if (firstData.nonEmpty) first.addData(firstData) + if (secondData.nonEmpty) second.addData(secondData) + query.processAllAvailable() + // add a dummy batch so lastExecution has the new watermark + first.addData(0) + query.processAllAvailable() + // get last watermark + val lastExecution = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution + lastExecution.offsetSeqMetadata.batchWatermarkMs + } + + // Global watermark starts at 0 until we get data from both sides + assert(getWatermarkAfterData(firstData = Seq(11)) == 0) + assert(getWatermarkAfterData(secondData = Seq(6)) == 1000) + // Global watermark stays at left watermark 1 when right watermark moves to 2 + assert(getWatermarkAfterData(secondData = Seq(8)) == 1000) + // Global watermark switches to right side value 2 when left watermark goes higher + assert(getWatermarkAfterData(firstData = Seq(21)) == 3000) + // Global watermark goes back to left + assert(getWatermarkAfterData(secondData = Seq(17, 28, 39)) == 11000) + // Global watermark stays on left as long as it's below right + assert(getWatermarkAfterData(firstData = Seq(31)) == 21000) + assert(getWatermarkAfterData(firstData = Seq(41)) == 31000) + // Global watermark switches back to right again + assert(getWatermarkAfterData(firstData = Seq(51)) == 34000) + + // Global watermark is updated correctly with simultaneous data from both sides + assert(getWatermarkAfterData(firstData = Seq(100), secondData = Seq(100)) == 90000) + assert(getWatermarkAfterData(firstData = Seq(120), secondData = Seq(110)) == 105000) + assert(getWatermarkAfterData(firstData = Seq(130), secondData = Seq(125)) == 120000) + + // Global watermark doesn't decrement with simultaneous data + assert(getWatermarkAfterData(firstData = Seq(100), secondData = Seq(100)) == 120000) + assert(getWatermarkAfterData(firstData = Seq(140), secondData = Seq(100)) == 120000) + assert(getWatermarkAfterData(firstData = Seq(100), secondData = Seq(135)) == 130000) + + // Global watermark 
recovers after restart, but left side watermark ahead of it does not. + assert(getWatermarkAfterData(firstData = Seq(200), secondData = Seq(190)) == 185000) + union.stop() + val union2 = unionWriter.start() + assert(getWatermarkAfterData(query = union2) == 185000) + // Even though the left side was ahead of 185000 in the last execution, the watermark won't + // increment until it gets past it in this execution. + assert(getWatermarkAfterData(secondData = Seq(200), query = union2) == 185000) + assert(getWatermarkAfterData(firstData = Seq(200), query = union2) == 190000) + } + } + test("complete mode") { val inputData = MemoryStream[Int] @@ -344,6 +459,44 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin assert(eventTimeColumns(0).name === "second") } + test("EventTime watermark should be ignored in batch query.") { + val df = testData + .withColumn("eventTime", $"key".cast("timestamp")) + .withWatermark("eventTime", "1 minute") + .select("eventTime") + .as[Long] + + checkDataset[Long](df, 1L to 100L: _*) + } + + test("SPARK-21565: watermark operator accepts attributes from replacement") { + withTempDir { dir => + dir.delete() + + val df = Seq(("a", 100.0, new java.sql.Timestamp(100L))) + .toDF("symbol", "price", "eventTime") + df.write.json(dir.getCanonicalPath) + + val input = spark.readStream.schema(df.schema) + .json(dir.getCanonicalPath) + + val groupEvents = input + .withWatermark("eventTime", "2 seconds") + .groupBy("symbol", "eventTime") + .agg(count("price") as 'count) + .select("symbol", "eventTime", "count") + val q = groupEvents.writeStream + .outputMode("append") + .format("console") + .start() + try { + q.processAllAvailable() + } finally { + q.stop() + } + } + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 1211242b9fbb..08db06b94904 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -19,11 +19,14 @@ package org.apache.spark.sql.streaming import java.util.Locale +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -62,6 +65,35 @@ class FileStreamSinkSuite extends StreamTest { } } + test("SPARK-21167: encode and decode path correctly") { + val inputData = MemoryStream[String] + val ds = inputData.toDS() + + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + val query = ds.map(s => (s, s.length)) + .toDF("value", "len") + .writeStream + .partitionBy("value") + .option("checkpointLocation", checkpointDir) + .format("parquet") + 
.start(outputDir) + + try { + // The output is partitoned by "value", so the value will appear in the file path. + // This is to test if we handle spaces in the path correctly. + inputData.addData("hello world") + failAfter(streamingTimeout) { + query.processAllAvailable() + } + val outputDf = spark.read.parquet(outputDir) + checkDatasetUnorderly(outputDf.as[(Int, String)], ("hello world".length, "hello world")) + } finally { + query.stop() + } + } + test("partitioned writing and batch reading") { val inputData = MemoryStream[Int] val ds = inputData.toDS() @@ -95,8 +127,7 @@ class FileStreamSinkSuite extends StreamTest { // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has // been inferred val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect { - case LogicalRelation(baseRelation, _, _) if baseRelation.isInstanceOf[HadoopFsRelation] => - baseRelation.asInstanceOf[HadoopFsRelation] + case LogicalRelation(baseRelation: HadoopFsRelation, _, _, _) => baseRelation } assert(hadoopdFsRelations.size === 1) assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex]) @@ -145,6 +176,43 @@ class FileStreamSinkSuite extends StreamTest { } } + test("partitioned writing and batch reading with 'basePath'") { + withTempDir { outputDir => + withTempDir { checkpointDir => + val outputPath = outputDir.getAbsolutePath + val inputData = MemoryStream[Int] + val ds = inputData.toDS() + + var query: StreamingQuery = null + + try { + query = + ds.map(i => (i, -i, i * 1000)) + .toDF("id1", "id2", "value") + .writeStream + .partitionBy("id1", "id2") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .format("parquet") + .start(outputPath) + + inputData.addData(1, 2, 3) + failAfter(streamingTimeout) { + query.processAllAvailable() + } + + val readIn = spark.read.option("basePath", outputPath).parquet(s"$outputDir/*/*") + checkDatasetUnorderly( + readIn.as[(Int, Int, Int)], + (1000, 1, -1), (2000, 2, -2), (3000, 3, -3)) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + // This tests whether FileStreamSink works with aggregations. Specifically, it tests // whether the correct streaming QueryExecution (i.e. IncrementalExecution) is used to // to execute the trigger for writing data to file sink. See SPARK-18440 for more details. 
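The next hunk adds direct tests for `FileStreamSink.ancestorIsMetadataDirectory()`. As a rough sketch of the check those tests exercise, the snippet below walks up a Hadoop `Path` and reports whether any component is the sink's metadata directory (`_spark_metadata`, the value I understand `FileStreamSink.metadataDir` to hold). The object name and the free-standing method are mine for illustration; the real helper also qualifies the path against a Hadoop `Configuration` before walking it.

```scala
// Illustrative sketch only -- not the patch's implementation of
// FileStreamSink.ancestorIsMetadataDirectory().
import org.apache.hadoop.fs.Path

object MetadataDirCheckSketch {
  // Assumed to match FileStreamSink.metadataDir.
  val metadataDir = "_spark_metadata"

  // Walk from the given path up to the filesystem root; true if any component
  // names the metadata directory.
  def ancestorIsMetadataDirectory(path: Path): Boolean = {
    var current: Path = path
    while (current != null) {
      if (current.getName == metadataDir) return true
      current = current.getParent
    }
    false
  }
}
```

Under this sketch, `/a/b/_spark_metadata/c` is treated as living under a metadata directory while `/a/b/c/_spark_metadataextra` is not, mirroring the positive and negative assertions the new test adds.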
@@ -266,4 +334,58 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("FileStreamSink.ancestorIsMetadataDirectory()") { + val hadoopConf = spark.sparkContext.hadoopConfiguration + def assertAncestorIsMetadataDirectory(path: String): Unit = + assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + def assertAncestorIsNotMetadataDirectory(path: String): Unit = + assert(!FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c/") + + assertAncestorIsNotMetadataDirectory(s"/a/b/c") + assertAncestorIsNotMetadataDirectory(s"/a/b/c/${FileStreamSink.metadataDir}extra") + } + + test("SPARK-20460 Check name duplication in schema") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val inputData = MemoryStream[(Int, Int)] + val df = inputData.toDF() + + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + var query: StreamingQuery = null + try { + query = + df.writeStream + .option("checkpointLocation", checkpointDir) + .format("json") + .start(outputDir) + + inputData.addData((1, 1)) + + failAfter(streamingTimeout) { + query.processAllAvailable() + } + } finally { + if (query != null) { + query.stop() + } + } + + val errorMsg = intercept[AnalysisException] { + spark.read.schema(s"$c0 INT, $c1 INT").json(outputDir).as[(Int, Int)] + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the data schema: ")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 2108b118bf05..b6baaed1927e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1105,7 +1105,10 @@ class FileStreamSourceSuite extends FileStreamSourceTest { def verify(startId: Option[Int], endId: Int, expected: String*): Unit = { val start = startId.map(new FileStreamSourceOffset(_)) val end = FileStreamSourceOffset(endId) - assert(fileSource.getBatch(start, end).as[String].collect().toSeq === expected) + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + assert(fileSource.getBatch(start, end).as[String].collect().toSeq === expected) + } } verify(startId = None, endId = 2, "keep1", "keep2", "keep3") @@ -1314,6 +1317,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val metadataLog = new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath) assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0)))) + assert(metadataLog.add(1, Array(FileEntry(s"$scheme:///file2", 200L, 0)))) val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil, dir.getAbsolutePath, Map.empty) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 85aa7dbe9ed8..9d74a5c701ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -73,14 +73,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(state.hasRemoved === shouldBeRemoved) } + // === Tests for state in streaming queries === // Updating empty state - state = new GroupStateImpl[String](None) + state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state - state = new GroupStateImpl[String](Some("2")) + state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false) testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -98,25 +99,35 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } - test("GroupState - setTimeout**** with NoTimeout") { - for (initState <- Seq(None, Some(5))) { - // for different initial state - implicit val state = new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + test("GroupState - setTimeout - with NoTimeout") { + for (initValue <- Seq(None, Some(5))) { + val states = Seq( + GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false), + GroupStateImpl.createForBatch(NoTimeout) + ) + for (state <- states) { + // for streaming queries + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + // for batch queries + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } } } - test("GroupState - setTimeout**** with ProcessingTimeTimeout") { - implicit var state: GroupStateImpl[Int] = null - - state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + test("GroupState - setTimeout - with ProcessingTimeTimeout") { + // for streaming queries + var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( + None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) - testTimeoutDurationNotAllowed[IllegalStateException](state) + state.setTimeoutDuration(500) + assert(state.getTimeoutTimestamp === 1500) // can 
be set without initializing state testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.update(5) - assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + assert(state.getTimeoutTimestamp === 1500) // does not change state.setTimeoutDuration(1000) assert(state.getTimeoutTimestamp === 2000) state.setTimeoutDuration("2 second") @@ -124,19 +135,38 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.remove() + assert(state.getTimeoutTimestamp === 3000) // does not change + state.setTimeoutDuration(500) // can still be set + assert(state.getTimeoutTimestamp === 1500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + // for batch queries + state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) - testTimeoutDurationNotAllowed[IllegalStateException](state) + state.setTimeoutDuration(500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.update(5) + state.setTimeoutDuration(1000) + state.setTimeoutDuration("2 second") + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.remove() + state.setTimeoutDuration(500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } - test("GroupState - setTimeout**** with EventTimeTimeout") { - implicit val state = new GroupStateImpl[Int]( - None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) + test("GroupState - setTimeout - with EventTimeTimeout") { + var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( + None, 1000, 1000, EventTimeTimeout, false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[IllegalStateException](state) + state.setTimeoutTimestamp(5000) + assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state state.update(5) + assert(state.getTimeoutTimestamp === 5000) // does not change state.setTimeoutTimestamp(10000) assert(state.getTimeoutTimestamp === 10000) state.setTimeoutTimestamp(new Date(20000)) @@ -144,57 +174,112 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.remove() + assert(state.getTimeoutTimestamp === 20000) + state.setTimeoutTimestamp(5000) + assert(state.getTimeoutTimestamp === 5000) // can be set after removing state + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + // for batch queries + state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[IllegalStateException](state) + state.setTimeoutTimestamp(5000) + + state.update(5) + state.setTimeoutTimestamp(10000) + state.setTimeoutTimestamp(new Date(20000)) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + state.remove() + state.setTimeoutTimestamp(5000) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) } - test("GroupState - illegal params to setTimeout****") { + test("GroupState - illegal params to setTimeout") { var state: GroupStateImpl[Int] = null // Test setTimeout****() with illegal values def testIllegalTimeout(body: => Unit): Unit = { - 
intercept[IllegalArgumentException] { body } + intercept[IllegalArgumentException] { + body + } assert(state.getTimeoutTimestamp === NO_TIMESTAMP) } - state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) - testIllegalTimeout { state.setTimeoutDuration(-1000) } - testIllegalTimeout { state.setTimeoutDuration(0) } - testIllegalTimeout { state.setTimeoutDuration("-2 second") } - testIllegalTimeout { state.setTimeoutDuration("-1 month") } - testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } + state = GroupStateImpl.createForStreaming( + Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + testIllegalTimeout { + state.setTimeoutDuration(-1000) + } + testIllegalTimeout { + state.setTimeoutDuration(0) + } + testIllegalTimeout { + state.setTimeoutDuration("-2 second") + } + testIllegalTimeout { + state.setTimeoutDuration("-1 month") + } + testIllegalTimeout { + state.setTimeoutDuration("1 month -1 day") + } - state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) - testIllegalTimeout { state.setTimeoutTimestamp(-10000) } - testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") } - testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") } - testIllegalTimeout { state.setTimeoutTimestamp(10000, "1 month -1 day") } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000)) } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-3 second") } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-1 month") } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") } + state = GroupStateImpl.createForStreaming( + Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + testIllegalTimeout { + state.setTimeoutTimestamp(-10000) + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "-3 second") + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "-1 month") + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "1 month -1 day") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000)) + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "-3 second") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "-1 month") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") + } } test("GroupState - hasTimedOut") { for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { + // for streaming queries for (initState <- Seq(None, Some(5))) { - val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) + val state1 = GroupStateImpl.createForStreaming( + initState, 1000, 1000, timeoutConf, hasTimedOut = false) assert(state1.hasTimedOut === false) - val state2 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) + + val state2 = GroupStateImpl.createForStreaming( + initState, 1000, 1000, timeoutConf, hasTimedOut = true) assert(state2.hasTimedOut === true) } + + // for batch queries + assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false) } } test("GroupState - primitive type") { - var intState = new GroupStateImpl[Int](None) + var intState = GroupStateImpl.createForStreaming[Int]( + None, 1000, 1000, NoTimeout, hasTimedOut = false) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) - intState = new GroupStateImpl[Int](Some(10)) + intState = GroupStateImpl.createForStreaming[Int]( + Some(10), 
1000, 1000, NoTimeout, hasTimedOut = false) assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -210,7 +295,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" @@ -318,6 +402,44 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } + // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), + // Try to remove these cases in the future + for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { + val testName = + if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = Some(5), + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { state.setTimeoutTimestamp(10000) }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, + timeoutConf = EventTimeTimeout, + priorState = Some(5), + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + } + // Tests for StateStoreUpdater.updateStateForTimedOutKeys() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { @@ -386,22 +508,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = Some(5), // state should change expectedTimeoutTimestamp = 5000) // timestamp should change - test("StateStoreUpdater - rows are cloned before writing to StateStore") { - // function for running count - val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { - state.update(state.getOption.getOrElse(0) + values.size) - Iterator.empty - } - val store = newStateStore() - val plan = newFlatMapGroupsWithStateExec(func) - val updater = new plan.StateStoreUpdater(store) - val data = Seq(1, 1, 2) - val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow)) - returnIter.size // consume the iterator to force store updates - val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet - assert(storeData === Set((1, 2), (2, 1))) - } - test("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise 
does not return anything @@ -558,7 +664,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(ProcessingTime("1 second"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), AdvanceManualClock(1 * 1000), CheckLastBatch(("a", "1")), @@ -589,7 +695,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf ) } - test("flatMapGroupsWithState - streaming with event time timeout") { + test("flatMapGroupsWithState - streaming with event time timeout + watermark") { // Function to maintain the max event time // Returns the max event time in the state, or -1 if the state was removed by timeout val stateFunc = ( @@ -623,7 +729,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(ProcessingTime("1 second")), + StartStream(Trigger.ProcessingTime("1 second")), AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s AddData(inputData, ("a", 4)), // Add data older than watermark for "a" @@ -677,15 +783,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("mapGroupsWithState - batch") { - val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + // Test the following + // - no initial state + // - timeouts operations work, does not throw any error [SPARK-20792] + // - works with primitive state type + val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") + state.setTimeoutTimestamp(0, "1 hour") + state.update(10) (key, values.size) } checkAnswer( spark.createDataset(Seq("a", "a", "b")) .groupByKey(x => x) - .mapGroupsWithState(stateFunc) + .mapGroupsWithState(EventTimeTimeout)(stateFunc) .toDF, spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) } @@ -761,6 +873,44 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(e.getMessage === "The output mode of function should be append or update") } + def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { + test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = + (key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => { + if (state.hasTimedOut) { + state.remove() + Iterator((key, "-1")) + } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, Long)] + val result = + inputData.toDF().toDF("key", "time") + .selectExpr("key", "cast(time as timestamp) as timestamp") + .withWatermark("timestamp", "10 second") + .as[(String, Long)] + .groupByKey(x => x._1) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = 
clock), + AddData(inputData, ("a", 1L)), + AdvanceManualClock(1 * 1000), + CheckLastBatch(("a", "1")) + ) + } + } + testWithTimeout(NoTimeout) + testWithTimeout(ProcessingTimeTimeout) + def testStateUpdateWithData( testName: String, stateUpdates: GroupState[Int] => Unit, @@ -768,7 +918,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf priorState: Option[Int], priorTimeoutTimestamp: Long = NO_TIMESTAMP, expectedState: Option[Int] = None, - expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { + expectedTimeoutTimestamp: Long = NO_TIMESTAMP, + expectedException: Class[_ <: Exception] = null): Unit = { if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state @@ -782,7 +933,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } testStateUpdate( testTimeoutUpdates = false, mapGroupsFunc, timeoutConf, - priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) + priorState, priorTimeoutTimestamp, + expectedState, expectedTimeoutTimestamp, expectedException) } } @@ -801,9 +953,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf stateUpdates(state) Iterator.empty } + testStateUpdate( testTimeoutUpdates = true, mapGroupsFunc, timeoutConf = timeoutConf, - preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) + preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp, null) } } @@ -814,7 +967,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf priorState: Option[Int], priorTimeoutTimestamp: Long, expectedState: Option[Int], - expectedTimeoutTimestamp: Long): Unit = { + expectedTimeoutTimestamp: Long, + expectedException: Class[_ <: Exception]): Unit = { val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( @@ -829,22 +983,30 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } // Call updating function to update state store - val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() - } else { - updater.updateStateForKeysWithData(Iterator(key)) + def callFunction() = { + val returnedIter = if (testTimeoutUpdates) { + updater.updateStateForTimedOutKeys() + } else { + updater.updateStateForKeysWithData(Iterator(key)) + } + returnedIter.size // consume the iterator to force state updates } - returnedIter.size // consumer the iterator to force state updates - - // Verify updated state in store - val updatedStateRow = store.get(key) - assert( - updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState, - "final state not as expected") - if (updatedStateRow.nonEmpty) { + if (expectedException != null) { + // Call function and verify the exception type + val e = intercept[Exception] { callFunction() } + assert(e.getClass === expectedException, "Exception thrown but of the wrong type") + } else { + // Call function to update and verify updated state in store + callFunction() + val updatedStateRow = store.get(key) assert( - updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") + Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, + "final state not as expected") + if (updatedStateRow != null) { + assert( + updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, + "final timeout timestamp not as 
expected") + } } } @@ -902,26 +1064,20 @@ object FlatMapGroupsWithStateSuite { import scala.collection.JavaConverters._ private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] - override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { - map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) } + override def iterator(): Iterator[UnsafeRowPair] = { + map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } } - override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { - iterator.filter { case (k, v) => c(k, v) } + override def get(key: UnsafeRow): UnsafeRow = map.get(key) + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = { + map.put(key.copy(), newValue.copy()) } - - override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key)) - override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue) override def remove(key: UnsafeRow): Unit = { map.remove(key) } - override def remove(condition: (UnsafeRow) => Boolean): Unit = { - iterator.map(_._1).filter(condition).foreach(map.remove) - } override def commit(): Long = version + 1 override def abort(): Unit = { } override def id: StateStoreId = null override def version: Long = 0 - override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException } - override def numKeys(): Long = map.size + override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) override def hasCommitted: Boolean = true } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 894786c50e23..368c4604dfca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming trait StateStoreMetricsTest extends StreamTest { def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = - AssertOnQuery { q => + AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q => val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get assert( progressWithData.stateOperators.map(_.numRowsTotal) === total, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 01ea62a9de4d..9c901062d570 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,20 +17,25 @@ package org.apache.spark.sql.streaming -import java.io.{File, InterruptedIOException, IOException} -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} +import java.io.{File, InterruptedIOException, IOException, UncheckedIOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit} import scala.reflect.ClassTag import scala.util.control.ControlThrowable +import com.google.common.util.concurrent.UncheckedExecutionException import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ 
+import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider @@ -71,6 +76,43 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + + test("explain join") { + // Make a table and ensure it will be broadcast. + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val inputData = MemoryStream[Int] + val joined = inputData.toDF().join(smallTable, smallTable("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain() + } + assert(outputStream.toString.contains("StreamingRelation")) + } + + test("SPARK-20432: union one stream with itself") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a") + val unioned = df.union(df) + withTempDir { outputDir => + withTempDir { checkpointDir => + val query = + unioned + .writeStream.format("parquet") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) + try { + query.processAllAvailable() + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] + checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*) + } finally { + query.stop() + } + } + } + } + test("union two streams") { val inputData1 = MemoryStream[Int] val inputData2 = MemoryStream[Int] @@ -122,6 +164,33 @@ class StreamSuite extends StreamTest { assertDF(df) } + test("Within the same streaming query, one StreamingRelation should only be transformed to one " + + "StreamingExecutionRelation") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load() + var query: StreamExecution = null + try { + query = + df.union(df) + .writeStream + .format("memory") + .queryName("memory") + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + query.awaitInitialization(streamingTimeout.toMillis) + val executionRelations = + query + .logicalPlan + .collect { case ser: StreamingExecutionRelation => ser } + assert(executionRelations.size === 2) + assert(executionRelations.distinct.size === 1) + } finally { + if (query != null) { + query.stop() + } + } + } + test("unsupported queries") { val streamInput = MemoryStream[Int] val batchInput = Seq(1, 2, 3).toDS() @@ -284,7 +353,9 @@ class StreamSuite extends StreamTest { override def stop(): Unit = {} } - val df = Dataset[Int](sqlContext.sparkSession, StreamingExecutionRelation(source)) + val df = Dataset[Int]( + sqlContext.sparkSession, + StreamingExecutionRelation(source, sqlContext.sparkSession)) testStream(df)( // `ExpectFailure(isFatalError = true)` verifies two things: // - Fatal errors can be propagated to `StreamingQuery.exception` and @@ -566,6 +637,105 @@ class StreamSuite extends StreamTest { assertDescContainsQueryNameAnd(batch = 2) query.stop() } + + test("should resolve the checkpoint path") { + withTempDir { dir => + val checkpointLocation = dir.getCanonicalPath + assert(!checkpointLocation.startsWith("file:/")) + val query = MemoryStream[Int].toDF + .writeStream + 
.option("checkpointLocation", checkpointLocation) + .format("console") + .start() + try { + val resolvedCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot + assert(resolvedCheckpointDir.startsWith("file:/")) + } finally { + query.stop() + } + } + } + + testQuietly("specify custom state store provider") { + val providerClassName = classOf[TestStateStoreProvider].getCanonicalName + withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) { + val input = MemoryStream[Int] + val df = input.toDS().groupBy().count() + val query = df.writeStream.outputMode("complete").format("memory").queryName("name").start() + input.addData(1, 2, 3) + val e = intercept[Exception] { + query.awaitTermination() + } + + assert(e.getMessage.contains(providerClassName)) + assert(e.getMessage.contains("instantiated")) + } + } + + testQuietly("custom state store provider read from offset log") { + val input = MemoryStream[Int] + val df = input.toDS().groupBy().count() + val providerConf1 = "spark.sql.streaming.stateStore.providerClass" -> + "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider" + val providerConf2 = "spark.sql.streaming.stateStore.providerClass" -> + classOf[TestStateStoreProvider].getCanonicalName + + def runQuery(queryName: String, checkpointLoc: String): Unit = { + val query = df.writeStream + .outputMode("complete") + .format("memory") + .queryName(queryName) + .option("checkpointLocation", checkpointLoc) + .start() + input.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + } + + withTempDir { dir => + val checkpointLoc1 = new File(dir, "1").getCanonicalPath + withSQLConf(providerConf1) { + runQuery("query1", checkpointLoc1) // generate checkpoints + } + + val checkpointLoc2 = new File(dir, "2").getCanonicalPath + withSQLConf(providerConf2) { + // Verify new query will use new provider that throw error on loading + intercept[Exception] { + runQuery("query2", checkpointLoc2) + } + + // Verify old query from checkpoint will still use old provider + runQuery("query1", checkpointLoc1) + } + } + } + + for (e <- Seq( + new InterruptedException, + new InterruptedIOException, + new ClosedByInterruptException, + new UncheckedIOException("test", new ClosedByInterruptException), + new ExecutionException("test", new InterruptedException), + new UncheckedExecutionException("test", new InterruptedException))) { + test(s"view ${e.getClass.getSimpleName} as a normal query stop") { + ThrowingExceptionInCreateSource.createSourceLatch = new CountDownLatch(1) + ThrowingExceptionInCreateSource.exception = e + val query = spark + .readStream + .format(classOf[ThrowingExceptionInCreateSource].getName) + .load() + .writeStream + .format("console") + .start() + assert(ThrowingExceptionInCreateSource.createSourceLatch + .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS), + "ThrowingExceptionInCreateSource.createSource wasn't called before timeout") + query.stop() + assert(query.exception.isEmpty) + } + } } abstract class FakeSource extends StreamSourceProvider { @@ -604,7 +774,16 @@ class FakeDefaultSource extends FakeSource { override def getBatch(start: Option[Offset], end: Offset): DataFrame = { val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1 - spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") + val ds = new Dataset[java.lang.Long]( + spark.sparkSession, + Range( + startOffset, + end.asInstanceOf[LongOffset].offset + 1, + 1, + 
Some(spark.sparkSession.sparkContext.defaultParallelism), + isStreaming = true), + Encoders.LONG) + ds.toDF("a") } override def stop() {} @@ -671,3 +850,51 @@ object ThrowingInterruptedIOException { */ @volatile var createSourceLatch: CountDownLatch = null } + +class TestStateStoreProvider extends StateStoreProvider { + + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], + storeConfs: StateStoreConf, + hadoopConf: Configuration): Unit = { + throw new Exception("Successfully instantiated") + } + + override def stateStoreId: StateStoreId = null + + override def close(): Unit = { } + + override def getStore(version: Long): StateStore = null +} + +/** A fake source that throws `ThrowingExceptionInCreateSource.exception` in `createSource` */ +class ThrowingExceptionInCreateSource extends FakeSource { + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + ThrowingExceptionInCreateSource.createSourceLatch.countDown() + try { + Thread.sleep(30000) + throw new TimeoutException("sleep was not interrupted in 30 seconds") + } catch { + case _: InterruptedException => + throw ThrowingExceptionInCreateSource.exception + } + } +} + +object ThrowingExceptionInCreateSource { + /** + * A latch to allow the user to wait until `ThrowingExceptionInCreateSource.createSource` is + * called. + */ + @volatile var createSourceLatch: CountDownLatch = null + @volatile var exception: Exception = null +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 5bc36dd30f6d..70b39b934071 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -26,9 +26,8 @@ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import org.scalatest.Assertions -import org.scalatest.concurrent.{Eventually, Timeouts} -import org.scalatest.concurrent.Eventually._ +import org.scalatest.{Assertions, BeforeAndAfterAll} +import org.scalatest.concurrent.{Eventually, Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span @@ -39,9 +38,10 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * A framework for implementing tests for streaming queries and sources. @@ -67,7 +67,13 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} * avoid hanging forever in the case of failures. However, individual suites can change this * by overriding `streamingTimeout`. 
*/ -trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { +trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with BeforeAndAfterAll { + + implicit val defaultSignaler: Signaler = ThreadSignaler + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() // stop the state store maintenance thread and unload store providers + } /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -161,7 +167,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { case class StartStream( trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, - additionalConfs: Map[String, String] = Map.empty) + additionalConfs: Map[String, String] = Map.empty, + checkpointLocation: String = null) extends StreamAction /** Advance the trigger clock's time manually. */ @@ -172,8 +179,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { * * @param isFatalError if this is a fatal error. If so, the error should also be caught by * UncaughtExceptionHandler. + * @param assertFailure a function to verify the error. */ case class ExpectFailure[T <: Throwable : ClassTag]( + assertFailure: Throwable => Unit = _ => {}, isFatalError: Boolean = false) extends StreamAction { val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] override def toString(): String = @@ -341,13 +350,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { """.stripMargin) } - val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath var manualClockExpectedTime = -1L + val defaultCheckpointLocation = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath try { startedTest.foreach { action => logInfo(s"Processing test stream action: $action") action match { - case StartStream(trigger, triggerClock, additionalConfs) => + case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => verify(currentStream == null, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], @@ -355,6 +365,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { if (triggerClock.isInstanceOf[StreamManualClock]) { manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() } + val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) additionalConfs.foreach(pair => { val value = @@ -455,6 +466,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") streamThreadDeathCause = null } + ef.assertFailure(exception.getCause) } catch { case _: InterruptedException => case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => @@ -470,7 +482,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { verify(currentStream != null || lastStream != null, "cannot assert when no stream has been started") val streamToAssert = Option(currentStream).getOrElse(lastStream) - verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + try { + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + } catch { + case NonFatal(e) => + failTest(s"Assert on query failed: ${a.message}", e) + } case a: Assert => val streamToAssert = Option(currentStream).getOrElse(lastStream) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index f796a4cb4a39..995cea3b37d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -19,24 +19,32 @@ package org.apache.spark.sql.streaming import java.util.{Locale, TimeZone} +import org.scalatest.Assertions import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkException -import org.apache.spark.sql.AnalysisException +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ -import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} +import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} -object FailureSinglton { +object FailureSingleton { var firstTime = true } -class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class StreamingAggregationSuite extends StateStoreMetricsTest + with BeforeAndAfterAll with Assertions { override def afterAll(): Unit = { super.afterAll() @@ -69,6 +77,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte ) } + test("count distinct") { + val inputData = MemoryStream[(Int, Seq[Int])] + + val aggregated = + inputData.toDF() + .select($"*", explode($"_2") as 'value) + .groupBy($"_1") + .agg(size(collect_set($"value"))) + .as[(Int, Int)] + + testStream(aggregated, Update)( + AddData(inputData, (1, Seq(1, 2))), + CheckLastBatch((1, 2)) + ) + } + test("simple count, complete mode") { val inputData = MemoryStream[Int] @@ -206,12 +230,12 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte testQuietly("midbatch failure") { val inputData = MemoryStream[Int] - FailureSinglton.firstTime = true + FailureSingleton.firstTime = true val aggregated = inputData.toDS() .map { i => - if (i == 4 && FailureSinglton.firstTime) { - FailureSinglton.firstTime = false + if (i == 4 && FailureSingleton.firstTime) { + FailureSingleton.firstTime = false sys.error("injected failure") } @@ -251,7 +275,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte .where('value >= current_timestamp().cast("long") - 10L) testStream(aggregated, Complete)( - StartStream(ProcessingTime("10 seconds"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock), // advance clock to 10 seconds, all keys retained AddData(inputData, 0L, 5L, 5L, 10L), @@ -278,7 +302,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte clock.advance(60 * 
1000L) true }, - StartStream(ProcessingTime("10 seconds"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock), // The commit log blown, causing the last batch to re-run CheckLastBatch((20L, 1), (85L, 1)), AssertOnQuery { q => @@ -306,7 +330,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte .where($"value".cast("date") >= date_sub(current_date(), 10)) .select(($"value".cast("long") / DateTimeUtils.SECONDS_PER_DAY).cast("long"), $"count(1)") testStream(aggregated, Complete)( - StartStream(ProcessingTime("10 day"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 day"), triggerClock = clock), // advance clock to 10 days, should retain all keys AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10), @@ -330,7 +354,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60) true }, - StartStream(ProcessingTime("10 day"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 day"), triggerClock = clock), // Commit log blown, causing a re-run of the last batch CheckLastBatch((20L, 1), (85L, 1)), @@ -340,4 +364,201 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte CheckLastBatch((90L, 1), (100L, 1), (105L, 1)) ) } + + test("SPARK-19690: do not convert batch aggregation in streaming query to streaming") { + val streamInput = MemoryStream[Int] + val batchDF = Seq(1, 2, 3, 4, 5) + .toDF("value") + .withColumn("parity", 'value % 2) + .groupBy('parity) + .agg(count("*") as 'joinValue) + val joinDF = streamInput + .toDF() + .join(batchDF, 'value === 'parity) + + // make sure we're planning an aggregate in the first place + assert(batchDF.queryExecution.optimizedPlan match { case _: Aggregate => true }) + + testStream(joinDF, Append)( + AddData(streamInput, 0, 1, 2, 3), + CheckLastBatch((0, 0, 2), (1, 1, 3)), + AddData(streamInput, 0, 1, 2, 3), + CheckLastBatch((0, 0, 2), (1, 1, 3))) + } + + /** + * This method verifies certain properties in the SparkPlan of a streaming aggregation. + * First of all, it checks that the child of a `StateStoreRestoreExec` creates the desired + * data distribution, where the child could be an Exchange, or a `HashAggregateExec` which already + * provides the expected data distribution. + * + * The second thing it checks that the child provides the expected number of partitions. + * + * The third thing it checks that we don't add an unnecessary shuffle in-between + * `StateStoreRestoreExec` and `StateStoreSaveExec`. 
+ */ + private def checkAggregationChain( + se: StreamExecution, + expectShuffling: Boolean, + expectedPartition: Int): Boolean = { + val executedPlan = se.lastExecution.executedPlan + val restore = executedPlan + .collect { case ss: StateStoreRestoreExec => ss } + .head + restore.child match { + case node: UnaryExecNode => + assert(node.outputPartitioning.numPartitions === expectedPartition, + "Didn't get the expected number of partitions.") + if (expectShuffling) { + assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}") + } else { + assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle") + } + + case _ => fail("Expected no shuffling") + } + var reachedRestore = false + // Check that there should be no exchanges after `StateStoreRestoreExec` + executedPlan.foreachUp { p => + if (reachedRestore) { + assert(!p.isInstanceOf[Exchange], "There should be no further exchanges") + } else { + reachedRestore = p.isInstanceOf[StateStoreRestoreExec] + } + } + true + } + + test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { + val inputSource = new BlockRDDBackedSource(spark) + MockSourceProvider.withMockSources(inputSource) { + // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default + // satisfies the required distributions of all aggregations. Therefore in our SparkPlan, we + // don't have any shuffling. However, `coalesce(1)` only guarantees that the RDD has at most 1 + // partition. Which means that if we have an input RDD with 0 partitions, nothing gets + // executed. Therefore the StateStore's don't save any delta files for a given trigger. This + // then leads to `FileNotFoundException`s in the subsequent batch. + // This isn't the only problem though. Once we introduce a shuffle before + // `StateStoreRestoreExec`, the input to the operator is an empty iterator. When performing + // `groupBy().agg(...)`, `HashAggregateExec` returns a `0` value for all aggregations. If + // we fail to restore the previous state in `StateStoreRestoreExec`, we save the 0 value in + // `StateStoreSaveExec` losing all previous state. + val aggregated: Dataset[Long] = + spark.readStream.format((new MockSourceProvider).getClass.getCanonicalName) + .load().coalesce(1).groupBy().count().as[Long] + + testStream(aggregated, Complete())( + AddBlockData(inputSource, Seq(1)), + CheckLastBatch(1), + AssertOnQuery("Verify no shuffling") { se => + checkAggregationChain(se, expectShuffling = false, 1) + }, + AddBlockData(inputSource), // create an empty trigger + CheckLastBatch(1), + AssertOnQuery("Verify addition of exchange operator") { se => + checkAggregationChain(se, expectShuffling = true, 1) + }, + AddBlockData(inputSource, Seq(2, 3)), + CheckLastBatch(3), + AddBlockData(inputSource), + CheckLastBatch(3), + StopStream + ) + } + } + + test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " + + "has non-empty grouping keys") { + val inputSource = new BlockRDDBackedSource(spark) + MockSourceProvider.withMockSources(inputSource) { + withTempDir { tempDir => + + // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default + // satisfies the required distributions of all aggregations. However, when we have + // non-empty grouping keys, in streaming, we must repartition to + // `spark.sql.shuffle.partitions`, otherwise only a single StateStore is used to process + // all keys. 
This may be fine, however, if the user removes the coalesce(1) or changes to + // a `coalesce(2)` for example, then the default behavior is to shuffle to + // `spark.sql.shuffle.partitions` many StateStores. When this happens, all StateStore's + // except 1 will be missing their previous delta files, which causes the stream to fail + // with FileNotFoundException. + def createDf(partitions: Int): Dataset[(Long, Long)] = { + spark.readStream + .format((new MockSourceProvider).getClass.getCanonicalName) + .load().coalesce(partitions).groupBy('a % 1).count().as[(Long, Long)] + } + + testStream(createDf(1), Complete())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddBlockData(inputSource, Seq(1)), + CheckLastBatch((0L, 1L)), + AssertOnQuery("Verify addition of exchange operator") { se => + checkAggregationChain( + se, + expectShuffling = true, + spark.sessionState.conf.numShufflePartitions) + }, + StopStream + ) + + testStream(createDf(2), Complete())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + Execute(se => se.processAllAvailable()), + AddBlockData(inputSource, Seq(2), Seq(3), Seq(4)), + CheckLastBatch((0L, 4L)), + AssertOnQuery("Verify no exchange added") { se => + checkAggregationChain( + se, + expectShuffling = false, + spark.sessionState.conf.numShufflePartitions) + }, + AddBlockData(inputSource), + CheckLastBatch((0L, 4L)), + StopStream + ) + } + } + } + + /** Add blocks of data to the `BlockRDDBackedSource`. */ + case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + source.addBlocks(data: _*) + (source, LongOffset(source.counter)) + } + } + + /** + * A Streaming Source that is backed by a BlockRDD and that can create RDDs with 0 blocks at will. + */ + class BlockRDDBackedSource(spark: SparkSession) extends Source { + var counter = 0L + private val blockMgr = SparkEnv.get.blockManager + private var blocks: Seq[BlockId] = Seq.empty + + def addBlocks(dataBlocks: Seq[Int]*): Unit = synchronized { + dataBlocks.foreach { data => + val id = TestBlockId(counter.toString) + blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY) + blocks ++= id :: Nil + counter += 1 + } + counter += 1 + } + + override def getOffset: Option[Offset] = synchronized { + if (counter == 0) None else Some(LongOffset(counter)) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { + val rdd = new BlockRDD[Int](spark.sparkContext, blocks.toArray) + .map(i => InternalRow(i)) // we don't really care about the values in this test + blocks = Seq.empty + spark.internalCreateDataFrame(rdd, schema, isStreaming = true).toDF() + } + override def schema: StructType = MockSourceProvider.fakeSchema + override def stop(): Unit = { + blockMgr.getMatchingBlockIds(_.isInstanceOf[TestBlockId]).foreach(blockMgr.removeBlock(_)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala new file mode 100644 index 000000000000..533e1165fd59 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -0,0 +1,472 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.util.UUID + +import scala.util.Random + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} +import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + + +class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { + + before { + SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator + } + + after { + StateStore.stop() + } + + import testImplicits._ + test("stream stream inner join on non-time column") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as "key", ('value * 2) as "leftValue") + val df2 = input2.toDF.select('value as "key", ('value * 3) as "rightValue") + val joined = df1.join(df2, "key") + + testStream(joined)( + AddData(input1, 1), + CheckAnswer(), + AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join + CheckLastBatch((1, 2, 3)), + AddData(input1, 10), // 10 arrived on input2 first, then input1, should join + CheckLastBatch((10, 20, 30)), + AddData(input2, 1), // another 1 in input2 should join with 1 input1 + CheckLastBatch((1, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) + CheckLastBatch((1, 2, 3), (1, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 100), + AddData(input2, 100), + CheckLastBatch((100, 200, 300)) + ) + } + + test("stream stream inner join on windows - without watermark") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF + .select('value as "key", 'value.cast("timestamp") as "timestamp", ('value * 2) as "leftValue") + .select('key, window('timestamp, "10 second"), 'leftValue) + + val df2 = input2.toDF + .select('value as "key", 'value.cast("timestamp") as "timestamp", + ('value * 3) as "rightValue") + .select('key, window('timestamp, "10 second"), 'rightValue) + + val joined = df1.join(df2, Seq("key", "window")) + .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(input1, 1), + CheckLastBatch(), + AddData(input2, 1), + CheckLastBatch((1, 10, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 25), + CheckLastBatch(), + StopStream, 
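// Added illustrative note, not in the original patch: with window('timestamp, "10 second"),
// the value 25 falls into the tumbling window [20, 30), so the matching row checked below is
// (25, 30, 50, 75): window.end = 30, leftValue = 25 * 2 = 50, rightValue = 25 * 3 = 75.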
+ StartStream(), + AddData(input2, 25), + CheckLastBatch((25, 30, 50, 75)), + AddData(input1, 1), + CheckLastBatch((1, 10, 2, 3)), // State for 1 still around as there is no watermark + StopStream, + StartStream(), + AddData(input1, 5), + CheckLastBatch(), + AddData(input2, 5), + CheckLastBatch((5, 10, 10, 15)) // No filter by any watermark + ) + } + + test("stream stream inner join on windows - with watermark") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF + .select('value as "key", 'value.cast("timestamp") as "timestamp", ('value * 2) as "leftValue") + .withWatermark("timestamp", "10 seconds") + .select('key, window('timestamp, "10 second"), 'leftValue) + + val df2 = input2.toDF + .select('value as "key", 'value.cast("timestamp") as "timestamp", + ('value * 3) as "rightValue") + .select('key, window('timestamp, "10 second"), 'rightValue) + + val joined = df1.join(df2, Seq("key", "window")) + .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(input1, 1), + CheckAnswer(), + assertNumStateRows(total = 1, updated = 1), + + AddData(input2, 1), + CheckLastBatch((1, 10, 2, 3)), + assertNumStateRows(total = 2, updated = 1), + StopStream, + StartStream(), + + AddData(input1, 25), + CheckLastBatch(), // since there is only 1 watermark operator, the watermark should be 15 + assertNumStateRows(total = 3, updated = 1), + + AddData(input2, 25), + CheckLastBatch((25, 30, 50, 75)), // watermark = 15 should remove 2 rows having window=[0,10] + assertNumStateRows(total = 2, updated = 1), + StopStream, + StartStream(), + + AddData(input2, 1), + CheckLastBatch(), // Should not join as < 15 removed + assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 + + AddData(input1, 5), + CheckLastBatch(), // Should not join or add to state as < 15 got filtered by watermark + assertNumStateRows(total = 2, updated = 0) + ) + } + + test("stream stream inner join with time range - with watermark - one side condition") { + import org.apache.spark.sql.functions._ + + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] + + val df1 = leftInput.toDF.toDF("leftKey", "time") + .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") + + val df2 = rightInput.toDF.toDF("rightKey", "time") + .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") + + val joined = + df1.join(df2, expr("leftKey = rightKey AND leftTime < rightTime - interval 5 seconds")) + .select('leftKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + + testStream(joined)( + AddData(leftInput, (1, 5)), + CheckAnswer(), + AddData(rightInput, (1, 11)), + CheckLastBatch((1, 5, 11)), + AddData(rightInput, (1, 10)), + CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 + assertNumStateRows(total = 3, updated = 1), + + // Increase event time watermark to 20s by adding data with time = 30s on both inputs + AddData(leftInput, (1, 3), (1, 30)), + CheckLastBatch((1, 3, 10), (1, 3, 11)), + assertNumStateRows(total = 5, updated = 2), + AddData(rightInput, (0, 30)), + CheckLastBatch(), + assertNumStateRows(total = 6, updated = 1), + + // event time watermark: max event time - 10 ==> 30 - 10 = 20 + // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 + + // Run another batch with event time 
= 25 to clear right state where rightTime <= 25 + AddData(rightInput, (0, 30)), + CheckLastBatch(), + assertNumStateRows(total = 5, updated = 1), // removed (1, 11) and (1, 10), added (0, 30) + + // New data to right input should match with left side (1, 3) and (1, 5), as left state should + // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and + // state rows with rightTime <= 25 should be removed from state. + // (1, 20) ==> filtered by event time watermark = 20 + // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state + // as state watermark = 25 + // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state + AddData(rightInput, (1, 20), (1, 21), (1, 28)), + CheckLastBatch((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), + assertNumStateRows(total = 6, updated = 1), + + // New data to left input with leftTime <= 20 should be filtered due to event time watermark + AddData(leftInput, (1, 20), (1, 21)), + CheckLastBatch((1, 21, 28)), + assertNumStateRows(total = 7, updated = 1) + ) + } + + test("stream stream inner join with time range - with watermark - two side conditions") { + import org.apache.spark.sql.functions._ + + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] + + val df1 = leftInput.toDF.toDF("leftKey", "time") + .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue") + .withWatermark("leftTime", "20 seconds") + + val df2 = rightInput.toDF.toDF("rightKey", "time") + .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue") + .withWatermark("rightTime", "30 seconds") + + val condition = expr( + "leftKey = rightKey AND " + + "leftTime BETWEEN rightTime - interval 10 seconds AND rightTime + interval 5 seconds") + + // This translates to leftTime <= rightTime + 5 seconds AND leftTime >= rightTime - 10 seconds + // So given leftTime, rightTime has to be BETWEEN leftTime - 5 seconds AND leftTime + 10 seconds + // + // =============== * ======================== * ============================== * ==> leftTime + // | | | + // |<---- 5s -->|<------ 10s ------>| |<------ 10s ------>|<---- 5s -->| + // | | | + // == * ============================== * =========>============== * ===============> rightTime + // + // E.g. 
+ // if rightTime = 60, then it matches only leftTime = [50, 65] + // if leftTime = 20, then it match only with rightTime = [15, 30] + // + // State value predicates + // left side: + // values allowed: leftTime >= rightTime - 10s ==> leftTime > eventTimeWatermark - 10 + // drop state where leftTime < eventTime - 10 + // right side: + // values allowed: rightTime >= leftTime - 5s ==> rightTime > eventTimeWatermark - 5 + // drop state where rightTime < eventTime - 5 + + val joined = + df1.join(df2, condition).select('leftKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + + testStream(joined)( + // If leftTime = 20, then it match only with rightTime = [15, 30] + AddData(leftInput, (1, 20)), + CheckAnswer(), + AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), + CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), + assertNumStateRows(total = 7, updated = 6), + + // If rightTime = 60, then it matches only leftTime = [50, 65] + AddData(rightInput, (1, 60)), + CheckLastBatch(), // matches with nothing on the left + AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), + CheckLastBatch((1, 50, 60), (1, 65, 60)), + assertNumStateRows(total = 12, updated = 4), + + // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 + // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) + // Should drop < 20 from left, i.e., none + // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) + // Should drop < 25 from the right, i.e., 14 and 15 + AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to stat + CheckLastBatch((1, 31, 26), (1, 31, 30), (1, 31, 31)), + assertNumStateRows(total = 11, updated = 1), // 12 - 2 removed + 1 added + + // Advance the watermark + AddData(rightInput, (1, 80)), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 1), + + // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 + // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) + // Should drop < 36 from left, i.e., 20, 31 (30 was not added) + // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) + // Should drop < 41 from the right, i.e., 25, 26, 30, 31 + AddData(rightInput, (1, 50)), + CheckLastBatch((1, 49, 50), (1, 50, 50)), + assertNumStateRows(total = 7, updated = 1) // 12 - 6 removed + 1 added + ) + } + + testQuietly("stream stream inner join without equality predicate") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue") + val df2 = input2.toDF.select('value as "rightKey", ('value * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey < rightKey")) + val e = intercept[Exception] { + val q = joined.writeStream.format("memory").queryName("test").start() + input1.addData(1) + q.awaitTermination(10000) + } + assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) + } + + testQuietly("extract watermark from time condition") { + val attributesToFindConstraintFor = Seq( + AttributeReference("leftTime", TimestampType)(), + AttributeReference("leftOther", IntegerType)()) + val metadataWithWatermark = new MetadataBuilder() + .putLong(EventTimeWatermark.delayKey, 1000) + .build() + val attributesWithWatermark = Seq( + AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), + 
AttributeReference("rightOther", IntegerType)()) + + def watermarkFrom( + conditionStr: String, + rightWatermark: Option[Long] = Some(10000)): Option[Long] = { + val conditionExpr = Some(conditionStr).map { str => + val plan = + Filter( + spark.sessionState.sqlParser.parseExpression(str), + LogicalRDD( + attributesToFindConstraintFor ++ attributesWithWatermark, + spark.sparkContext.emptyRDD)(spark)) + plan.queryExecution.optimizedPlan.asInstanceOf[Filter].condition + } + StreamingSymmetricHashJoinHelper.getStateValueWatermark( + AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), + conditionExpr, rightWatermark) + } + + // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, + // then cannot define constraint on leftTime. + assert(watermarkFrom("leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) + assert(watermarkFrom("leftTime < rightTime") === None) + assert(watermarkFrom("leftTime <= rightTime") === None) + assert(watermarkFrom("rightTime > leftTime") === None) + assert(watermarkFrom("rightTime >= leftTime") === None) + assert(watermarkFrom("rightTime < leftTime") === Some(10000)) + assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) + + // Test type conversions + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) + + // Test with timestamp type + calendar interval on either side of equation + // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. + assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) + assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) + assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) + assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) + assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") + === Some(12000)) + + // Test with casted long type + constants on either side of equation + // Note: long type and constants commute, so more combinations to test. 
+ // -- Constants on the right + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") + === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") + === Some(10100)) + assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") + === Some(10200)) + // -- Constants on the left + assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) + assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) + assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") + === Some(7000)) + assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") + === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") + === Some(10100)) + // -- Constants on both sides, mixed types + assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") + === Some(13000)) + + // Test multiple conditions, should return minimum watermark + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === + Some(7000)) // first condition wins + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") === + Some(6000)) // second condition wins + + // Test invalid comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes + assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed + + // Test static comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) + } + + test("locality preferences of StateStoreAwareZippedRDD") { + import StreamingSymmetricHashJoinHelper._ + + withTempDir { tempDir => + val queryId = UUID.randomUUID + val opId = 0 + val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString + val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L) + + implicit val sqlContext = spark.sqlContext + val coordinatorRef = sqlContext.streams.stateStoreCoordinator + val numPartitions = 5 + val storeNames = Seq("name1", "name2") + + val partitionAndStoreNameToLocation = { + for (partIndex <- 0 until numPartitions; storeName <- storeNames) yield { + (partIndex, storeName) -> s"host-$partIndex-$storeName" + } + }.toMap + partitionAndStoreNameToLocation.foreach { case ((partIndex, storeName), hostName) => + val providerId = StateStoreProviderId(stateInfo, 
partIndex, storeName) + coordinatorRef.reportActiveInstance(providerId, hostName, s"exec-$hostName") + require( + coordinatorRef.getLocation(providerId) === + Some(ExecutorCacheTaskLocation(hostName, s"exec-$hostName").toString)) + } + + val rdd1 = spark.sparkContext.makeRDD(1 to 10, numPartitions) + val rdd2 = spark.sparkContext.makeRDD((1 to 10).map(_.toString), numPartitions) + val rdd = rdd1.stateStoreAwareZipPartitions(rdd2, stateInfo, storeNames, coordinatorRef) { + (left, right) => left.zip(right) + } + require(rdd.partitions.length === numPartitions) + for (partIndex <- 0 until numPartitions) { + val expectedLocations = storeNames.map { storeName => + val hostName = partitionAndStoreNameToLocation((partIndex, storeName)) + ExecutorCacheTaskLocation(hostName, s"exec-$hostName").toString + }.toSet + assert(rdd.preferredLocations(rdd.partitions(partIndex)).toSet === expectedLocations) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index b8a694c17731..1fe639fcf284 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -20,14 +20,13 @@ package org.apache.spark.sql.streaming import java.util.UUID import scala.collection.mutable -import scala.concurrent.duration._ +import scala.language.reflectiveCalls import org.scalactic.TolerantNumerics -import org.scalatest.concurrent.AsyncAssertions.Waiter -import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.BeforeAndAfter import org.scalatest.PrivateMethodTester._ +import org.scalatest.concurrent.AsyncAssertions.Waiter +import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.apache.spark.SparkException import org.apache.spark.scheduler._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index b49efa689023..46eec736d402 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -78,9 +78,9 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { eventually(Timeout(streamingTimeout)) { require(!q2.isActive) require(q2.exception.isDefined) + assert(spark.streams.get(q2.id) === null) + assert(spark.streams.active.toSet === Set(q3)) } - assert(spark.streams.get(q2.id) === null) - assert(spark.streams.active.toSet === Set(q3)) } } @@ -289,7 +289,7 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { } } - AwaitTerminationTester.test(expectedBehavior, awaitTermFunc, testBehaviorFor) + AwaitTerminationTester.test(expectedBehavior, () => awaitTermFunc(), testBehaviorFor) } /** Stop a random active query either with `stop()` or with an error */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 901cf34f289c..79bb827e0de9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -33,22 +33,17 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite._ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { - implicit class EqualsIgnoreCRLF(source: String) { - def equalsIgnoreCRLF(target: String): Boolean = { - source.replaceAll("\r\n|\r|\n", System.lineSeparator) === - target.replaceAll("\r\n|\r|\n", System.lineSeparator) - } - } - test("StreamingQueryProgress - prettyJson") { val json1 = testProgress1.prettyJson - assert(json1.equalsIgnoreCRLF( + assertJson( + json1, s""" |{ | "id" : "${testProgress1.id.toString}", | "runId" : "${testProgress1.runId.toString}", | "name" : "myName", | "timestamp" : "2016-12-05T20:54:20.827Z", + | "batchId" : 2, | "numInputRows" : 678, | "inputRowsPerSecond" : 10.0, | "durationMs" : { @@ -62,7 +57,8 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | }, | "stateOperators" : [ { | "numRowsTotal" : 0, - | "numRowsUpdated" : 1 + | "numRowsUpdated" : 1, + | "memoryUsedBytes" : 2 | } ], | "sources" : [ { | "description" : "source", @@ -75,25 +71,27 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "description" : "sink" | } |} - """.stripMargin.trim)) + """.stripMargin.trim) assert(compact(parse(json1)) === testProgress1.json) val json2 = testProgress2.prettyJson - assert( - json2.equalsIgnoreCRLF( - s""" + assertJson( + json2, + s""" |{ | "id" : "${testProgress2.id.toString}", | "runId" : "${testProgress2.runId.toString}", | "name" : null, | "timestamp" : "2016-12-05T20:54:20.827Z", + | "batchId" : 2, | "numInputRows" : 678, | "durationMs" : { | "total" : 0 | }, | "stateOperators" : [ { | "numRowsTotal" : 0, - | "numRowsUpdated" : 1 + | "numRowsUpdated" : 1, + | "memoryUsedBytes" : 2 | } ], | "sources" : [ { | "description" : "source", @@ -105,7 +103,7 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "description" : "sink" | } |} - """.stripMargin.trim)) + """.stripMargin.trim) assert(compact(parse(json2)) === testProgress2.json) } @@ -121,14 +119,15 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { test("StreamingQueryStatus - prettyJson") { val json = testStatus.prettyJson - assert(json.equalsIgnoreCRLF( + assertJson( + json, """ |{ | "message" : "active", | "isDataAvailable" : true, | "isTriggerActive" : false |} - """.stripMargin.trim)) + """.stripMargin.trim) } test("StreamingQueryStatus - json") { @@ -209,6 +208,12 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { } } } + + def assertJson(source: String, expected: String): Unit = { + assert( + source.replaceAll("\r\n|\r|\n", System.lineSeparator) === + expected.replaceAll("\r\n|\r|\n", System.lineSeparator)) + } } object StreamingQueryStatusAndProgressSuite { @@ -224,7 +229,8 @@ object StreamingQueryStatusAndProgressSuite { "min" -> "2016-12-05T20:54:20.827Z", "avg" -> "2016-12-05T20:54:20.827Z", "watermark" -> "2016-12-05T20:54:20.827Z").asJava), - stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)), + stateOperators = Array(new StateOperatorProgress( + numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2)), sources = Array( new SourceProgress( description = "source", @@ -247,7 +253,8 @@ object StreamingQueryStatusAndProgressSuite { durationMs = new java.util.HashMap(Map("total" -> 
0L).mapValues(long2Long).asJava), // empty maps should be handled correctly eventTime = new java.util.HashMap(Map.empty[String, String].asJava), - stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)), + stateOperators = Array(new StateOperatorProgress( + numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2)), sources = Array( new SourceProgress( description = "source", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index b69536ed3746..ab35079dca23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -22,19 +22,19 @@ import java.util.concurrent.CountDownLatch import org.apache.commons.lang3.RandomStringUtils import org.mockito.Mockito._ import org.scalactic.TolerantNumerics -import org.scalatest.concurrent.Eventually._ import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout -import org.scalatest.mock.MockitoSugar +import org.scalatest.mockito.MockitoSugar +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.ManualClock @@ -466,7 +466,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckAnswer(6, 3, 6, 3, 1, 1), AssertOnQuery("metadata log should contain only two files") { q => - val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) + val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toUri) val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) val toTest = logFileNames.filter(!_.endsWith(".crc")).sorted // Workaround for SPARK-17475 assert(toTest.size == 2 && toTest.head == "1") @@ -492,7 +492,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckAnswer(1, 2, 1, 2, 3, 4, 5, 6, 7, 8), AssertOnQuery("metadata log should contain three files") { q => - val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) + val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toUri) val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) val toTest = logFileNames.filter(!_.endsWith(".crc")).sorted // Workaround for SPARK-17475 assert(toTest.size == 3 && toTest.head == "2") @@ -613,6 +613,45 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("get the query id in source") { + @volatile var queryId: String = null + val source = new Source { + override def stop(): Unit = {} + override def getOffset: Option[Offset] = { + queryId = spark.sparkContext.getLocalProperty(StreamExecution.QUERY_ID_KEY) + None + } + override def getBatch(start: Option[Offset], end: Offset): DataFrame = spark.emptyDataFrame + override def schema: StructType = MockSourceProvider.fakeSchema + } + + MockSourceProvider.withMockSources(source) { + val df = spark.readStream + 
.format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + testStream(df)( + AssertOnQuery { sq => + sq.processAllAvailable() + assert(sq.id.toString === queryId) + assert(sq.runId.toString !== queryId) + true + } + ) + } + } + + test("processAllAvailable should not block forever when a query is stopped") { + val input = MemoryStream[Int] + input.addData(1) + val query = input.toDF().writeStream + .trigger(Trigger.Once()) + .format("console") + .start() + failAfter(streamingTimeout) { + query.processAllAvailable() + } + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) @@ -620,10 +659,13 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi val source = new Source() { override def schema: StructType = triggerDF.schema override def getOffset: Option[Offset] = Some(LongOffset(0)) - override def getBatch(start: Option[Offset], end: Offset): DataFrame = triggerDF + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + sqlContext.internalCreateDataFrame( + triggerDF.queryExecution.toRdd, triggerDF.schema, isStreaming = true) + } override def stop(): Unit = {} } - StreamingExecutionRelation(source) + StreamingExecutionRelation(source, spark) } /** Returns the query progress at the end of the first trigger of streaming DF */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index dc2506a48ad0..aa163d2211c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -88,7 +88,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { override def getBatch(start: Option[Offset], end: Offset): DataFrame = { import spark.implicits._ - Seq[Int]().toDS().toDF() + spark.internalCreateDataFrame(spark.sparkContext.emptyRDD, schema, isStreaming = true) } override def stop() {} @@ -378,14 +378,14 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { verify(LastOptions.mockStreamSourceProvider).createSource( any(), - meq(s"$checkpointLocationURI/sources/0"), + meq(s"${makeQualifiedPath(checkpointLocationURI.toString)}/sources/0"), meq(None), meq("org.apache.spark.sql.streaming.test"), meq(Map.empty)) verify(LastOptions.mockStreamSourceProvider).createSource( any(), - meq(s"$checkpointLocationURI/sources/1"), + meq(s"${makeQualifiedPath(checkpointLocationURI.toString)}/sources/1"), meq(None), meq("org.apache.spark.sql.streaming.test"), meq(Map.empty)) @@ -641,8 +641,9 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { test("temp checkpoint dir should be deleted if a query is stopped without errors") { import testImplicits._ val query = MemoryStream[Int].toDS.writeStream.format("console").start() + query.processAllAvailable() val checkpointDir = new Path( - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.checkpointRoot) + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot) val fs = checkpointDir.getFileSystem(spark.sessionState.newHadoopConf()) assert(fs.exists(checkpointDir)) query.stop() @@ -654,7 +655,7 @@ class DataStreamReaderWriterSuite extends StreamTest with 
BeforeAndAfter { val input = MemoryStream[Int] val query = input.toDS.map(_ / 0).writeStream.format("console").start() val checkpointDir = new Path( - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.checkpointRoot) + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot) val fs = checkpointDir.getFileSystem(spark.sessionState.newHadoopConf()) assert(fs.exists(checkpointDir)) input.addData(1) @@ -663,4 +664,16 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } assert(fs.exists(checkpointDir)) } + + test("SPARK-20431: Specify a schema by using a DDL-formatted string") { + spark.readStream + .format("org.apache.spark.sql.streaming.test") + .schema("aa INT") + .load() + + assert(LastOptions.schema.isDefined) + assert(LastOptions.schema.get === StructType(StructField("aa", IntegerType) :: Nil)) + + LastOptions.clear() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index fb15e7def6db..569bac156b53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -128,6 +129,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be import testImplicits._ private val userSchema = new StructType().add("s", StringType) + private val userSchemaString = "s STRING" private val textSchema = new StructType().add("value", StringType) private val data = Seq("1", "2", "3") private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath @@ -678,4 +680,99 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be assert(e.contains("User specified schema not supported with `table`")) } } + + test("SPARK-20431: Specify a schema by using a DDL-formatted string") { + spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir) + testRead(spark.read.schema(userSchemaString).text(), Seq.empty, userSchema) + testRead(spark.read.schema(userSchemaString).text(dir), data, userSchema) + testRead(spark.read.schema(userSchemaString).text(dir, dir), data ++ data, userSchema) + testRead(spark.read.schema(userSchemaString).text(Seq(dir, dir): _*), data ++ data, userSchema) + } + + test("SPARK-20460 Check name duplication in buckets") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var errorMsg = intercept[AnalysisException] { + Seq((1, 1)).toDF("col", c0).write.bucketBy(2, c0, c1).saveAsTable("t") + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the bucket definition")) + + errorMsg = intercept[AnalysisException] { + Seq((1, 1)).toDF("col", c0).write.bucketBy(2, "col").sortBy(c0, c1).saveAsTable("t") + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the sort definition")) + } + } + } + + test("SPARK-20460 Check name duplication in schema") { + def checkWriteDataColumnDuplication( + 
format: String, colName0: String, colName1: String, tempDir: File): Unit = { + val errorMsg = intercept[AnalysisException] { + Seq((1, 1)).toDF(colName0, colName1).write.format(format).mode("overwrite") + .save(tempDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) when inserting into")) + } + + def checkReadUserSpecifiedDataColumnDuplication( + df: DataFrame, format: String, colName0: String, colName1: String, tempDir: File): Unit = { + val testDir = Utils.createTempDir(tempDir.getAbsolutePath) + df.write.format(format).mode("overwrite").save(testDir.getAbsolutePath) + val errorMsg = intercept[AnalysisException] { + spark.read.format(format).schema(s"$colName0 INT, $colName1 INT") + .load(testDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the data schema:")) + } + + def checkReadPartitionColumnDuplication( + format: String, colName0: String, colName1: String, tempDir: File): Unit = { + val testDir = Utils.createTempDir(tempDir.getAbsolutePath) + Seq(1).toDF("col").write.format(format).mode("overwrite") + .save(s"${testDir.getAbsolutePath}/$colName0=1/$colName1=1") + val errorMsg = intercept[AnalysisException] { + spark.read.format(format).load(testDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the partition schema:")) + } + + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + withTempDir { src => + // Check CSV format + checkWriteDataColumnDuplication("csv", c0, c1, src) + checkReadUserSpecifiedDataColumnDuplication( + Seq((1, 1)).toDF("c0", "c1"), "csv", c0, c1, src) + // If `inferSchema` is true, a CSV format is duplicate-safe (See SPARK-16896) + var testDir = Utils.createTempDir(src.getAbsolutePath) + Seq("a,a", "1,1").toDF().coalesce(1).write.mode("overwrite").text(testDir.getAbsolutePath) + val df = spark.read.format("csv").option("inferSchema", true).option("header", true) + .load(testDir.getAbsolutePath) + checkAnswer(df, Row(1, 1)) + checkReadPartitionColumnDuplication("csv", c0, c1, src) + + // Check JSON format + checkWriteDataColumnDuplication("json", c0, c1, src) + checkReadUserSpecifiedDataColumnDuplication( + Seq((1, 1)).toDF("c0", "c1"), "json", c0, c1, src) + // Inferred schema cases + testDir = Utils.createTempDir(src.getAbsolutePath) + Seq(s"""{"$c0":3, "$c1":5}""").toDF().write.mode("overwrite") + .text(testDir.getAbsolutePath) + val errorMsg = intercept[AnalysisException] { + spark.read.format("json").option("inferSchema", true).load(testDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the data schema:")) + checkReadPartitionColumnDuplication("json", c0, c1, src) + + // Check Parquet format + checkWriteDataColumnDuplication("parquet", c0, c1, src) + checkReadUserSpecifiedDataColumnDuplication( + Seq((1, 1)).toDF("c0", "c1"), "parquet", c0, c1, src) + checkReadPartitionColumnDuplication("parquet", c0, c1, src) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index f9b3ff840582..0cfe260e5215 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -21,7 +21,6 @@ import java.nio.charset.StandardCharsets import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, 
SparkSession, SQLContext, SQLImplicits} -import org.apache.spark.sql.internal.SQLConf /** * A collection of sample data used in SQL tests. @@ -29,8 +28,6 @@ import org.apache.spark.sql.internal.SQLConf private[sql] trait SQLTestData { self => protected def spark: SparkSession - protected def sqlConf: SQLConf = spark.sessionState.conf - // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { protected override def _sqlContext: SQLContext = self.spark.sqlContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 44c0fc70d066..a14a1441a431 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -32,12 +32,14 @@ import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE -import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{UninterruptibleThread, Utils} /** @@ -53,7 +55,8 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} private[sql] trait SQLTestUtils extends SparkFunSuite with Eventually with BeforeAndAfterAll - with SQLTestData { self => + with SQLTestData + with PlanTest { self => protected def sparkContext = spark.sparkContext @@ -89,28 +92,9 @@ private[sql] trait SQLTestUtils } } - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map { key => - if (spark.conf.contains(key)) { - Some(spark.conf.get(key)) - } else { - None - } - } - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } + protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + SparkSession.setActiveSession(spark) + super.withSQLConf(pairs: _*)(f) } /** @@ -149,6 +133,7 @@ private[sql] trait SQLTestUtils .getExecutorInfos.map(_.numRunningTasks()).sum == 0) } } + /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. @@ -164,6 +149,19 @@ private[sql] trait SQLTestUtils } } + /** + * Creates the specified number of temporary directories, which is then passed to `f` and will be + * deleted after `f` returns. + */ + protected def withTempPaths(numPaths: Int)(f: Seq[File] => Unit): Unit = { + val files = Array.fill[File](numPaths)(Utils.createTempDir().getCanonicalFile) + try f(files) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + files.foreach(Utils.deleteRecursively) + } + } + /** * Drops functions after calling `f`. 
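The two helpers above (`withSQLConf`, now inherited from `PlanTest`, and the new `withTempPaths`) are easiest to read with a usage example. A minimal sketch, assuming a suite that mixes in `SQLTestUtils` so `spark`, the helpers, and `org.apache.spark.sql.internal.SQLConf` are in scope; the test name and the data written are illustrative, not part of the patch:

```scala
test("illustrative: conf override plus scratch directories") {
  // The SQL conf value applies only inside this block and is restored afterwards.
  withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
    // Both directories exist inside the block and are deleted once it returns,
    // after waitForTasksToFinish() has run.
    withTempPaths(numPaths = 2) { case Seq(src, dst) =>
      spark.range(10).write.parquet(src.getCanonicalPath)
      spark.read.parquet(src.getCanonicalPath).write.parquet(dst.getCanonicalPath)
      assert(spark.read.parquet(dst.getCanonicalPath).count() === 10)
    }
  }
}
```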
A function is represented by (functionName, isTemporary). */ @@ -237,7 +235,7 @@ private[sql] trait SQLTestUtils try f(dbName) finally { if (spark.catalog.currentDatabase == dbName) { - spark.sql(s"USE ${DEFAULT_DATABASE}") + spark.sql(s"USE $DEFAULT_DATABASE") } spark.sql(s"DROP DATABASE $dbName CASCADE") } @@ -249,8 +247,9 @@ private[sql] trait SQLTestUtils protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { try f finally { dbNames.foreach { name => - spark.sql(s"DROP DATABASE IF EXISTS $name") + spark.sql(s"DROP DATABASE IF EXISTS $name CASCADE") } + spark.sql(s"USE $DEFAULT_DATABASE") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 81c69a338abc..cd8d0708d8a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -24,6 +24,7 @@ import org.scalatest.concurrent.Eventually import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. @@ -31,7 +32,10 @@ import org.apache.spark.sql.{SparkSession, SQLContext} trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { protected def sparkConf = { - new SparkConf().set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") } /** @@ -74,6 +78,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventua protected override def afterAll(): Unit = { super.afterAll() if (_spark != null) { + _spark.sessionState.catalog.reset() _spark.stop() _spark = null } @@ -86,6 +91,8 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventua protected override def afterEach(): Unit = { super.afterEach() + // Clear all persistent datasets after each test + spark.sharedState.cacheManager.clearCache() // files can be closed from other threads, so wait a bit // normally this doesn't take more than 1s eventually(timeout(10.seconds)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 7c9ea7d39363..a239e39d9c5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.{functions, AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project} import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec} -import org.apache.spark.sql.execution.datasources.{CreateTable, SaveIntoDataSourceCommand} +import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand} +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.test.SharedSQLContext class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { @@ -178,26 +179,28 @@ class DataFrameCallbackSuite extends 
QueryTest with SharedSQLContext { spark.range(10).write.format("json").save(path.getCanonicalPath) assert(commands.length == 1) assert(commands.head._1 == "save") - assert(commands.head._2.isInstanceOf[SaveIntoDataSourceCommand]) - assert(commands.head._2.asInstanceOf[SaveIntoDataSourceCommand].provider == "json") + assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand]) + assert(commands.head._2.asInstanceOf[InsertIntoHadoopFsRelationCommand] + .fileFormat.isInstanceOf[JsonFileFormat]) } withTable("tab") { - sql("CREATE TABLE tab(i long) using parquet") + sql("CREATE TABLE tab(i long) using parquet") // adds commands(1) via onSuccess spark.range(10).write.insertInto("tab") - assert(commands.length == 2) - assert(commands(1)._1 == "insertInto") - assert(commands(1)._2.isInstanceOf[InsertIntoTable]) - assert(commands(1)._2.asInstanceOf[InsertIntoTable].table + assert(commands.length == 3) + assert(commands(2)._1 == "insertInto") + assert(commands(2)._2.isInstanceOf[InsertIntoTable]) + assert(commands(2)._2.asInstanceOf[InsertIntoTable].table .asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab") } + // exiting withTable adds commands(3) via onSuccess (drops tab) withTable("tab") { spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab") - assert(commands.length == 3) - assert(commands(2)._1 == "saveAsTable") - assert(commands(2)._2.isInstanceOf[CreateTable]) - assert(commands(2)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p")) + assert(commands.length == 5) + assert(commands(4)._1 == "saveAsTable") + assert(commands(4)._2.isInstanceOf[CreateTable]) + assert(commands(4)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p")) } withTable("tab") { diff --git a/sql/create-docs.sh b/sql/create-docs.sh new file mode 100755 index 000000000000..4353708d22f7 --- /dev/null +++ b/sql/create-docs.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Script to create SQL API docs. This requires `mkdocs` and to build +# Spark first. After running this script the html docs can be found in +# $SPARK_HOME/sql/site + +set -o pipefail +set -e + +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +SPARK_HOME="$(cd "`dirname "${BASH_SOURCE[0]}"`"/..; pwd)" + +if ! hash python 2>/dev/null; then + echo "Missing python in your path, skipping SQL documentation generation." + exit 0 +fi + +if ! hash mkdocs 2>/dev/null; then + echo "Missing mkdocs in your path, trying to install mkdocs for SQL documentation generation." + pip install mkdocs +fi + +pushd "$FWDIR" > /dev/null + +# Now create the markdown file +rm -fr docs +mkdir docs +echo "Generating markdown files for SQL documentation." 
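Stepping back to the `DataFrameCallbackSuite` changes above: the shifted indices come from extra `onSuccess` callbacks fired for the `CREATE TABLE` and the implicit drop issued by `withTable` around the writes. A minimal sketch (not taken from the patch; the output path is illustrative) of the listener pattern the suite relies on:

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Collect (action name, logical plan) pairs, the same shape as the suite's
// `commands` buffer.
val commands = ArrayBuffer.empty[(String, LogicalPlan)]
val listener = new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    commands += funcName -> qe.logical
  }
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
}
spark.listenerManager.register(listener)

// A DataFrameWriter action reports its name ("save" here) together with the
// executed plan, which is what the assertions above inspect.
spark.range(10).write.format("json").save("/tmp/illustrative-output")
```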
+"$SPARK_HOME/bin/spark-submit" gen-sql-markdown.py + +# Now create the HTML files +echo "Generating HTML files for SQL documentation." +mkdocs build --clean +rm -fr docs + +popd diff --git a/sql/gen-sql-markdown.py b/sql/gen-sql-markdown.py new file mode 100644 index 000000000000..fa8124b4513a --- /dev/null +++ b/sql/gen-sql-markdown.py @@ -0,0 +1,197 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import os +from collections import namedtuple + +ExpressionInfo = namedtuple( + "ExpressionInfo", "className name usage arguments examples note since") + + +def _list_function_infos(jvm): + """ + Returns a list of function information via JVM. Sorts wrapped expression infos by name + and returns them. + """ + + jinfos = jvm.org.apache.spark.sql.api.python.PythonSQLUtils.listBuiltinFunctionInfos() + infos = [] + for jinfo in jinfos: + name = jinfo.getName() + usage = jinfo.getUsage() + usage = usage.replace("_FUNC_", name) if usage is not None else usage + infos.append(ExpressionInfo( + className=jinfo.getClassName(), + name=name, + usage=usage, + arguments=jinfo.getArguments().replace("_FUNC_", name), + examples=jinfo.getExamples().replace("_FUNC_", name), + note=jinfo.getNote(), + since=jinfo.getSince())) + return sorted(infos, key=lambda i: i.name) + + +def _make_pretty_usage(usage): + """ + Makes the usage description pretty and returns a formatted string if `usage` + is not an empty string. Otherwise, returns None. + """ + + if usage is not None and usage.strip() != "": + usage = "\n".join(map(lambda u: u.strip(), usage.split("\n"))) + return "%s\n\n" % usage + + +def _make_pretty_arguments(arguments): + """ + Makes the arguments description pretty and returns a formatted string if `arguments` + starts with the argument prefix. Otherwise, returns None. + + Expected input: + + Arguments: + * arg0 - ... + ... + * arg0 - ... + ... + + Expected output: + **Arguments:** + + * arg0 - ... + ... + * arg0 - ... + ... + + """ + + if arguments.startswith("\n Arguments:"): + arguments = "\n".join(map(lambda u: u[6:], arguments.strip().split("\n")[1:])) + return "**Arguments:**\n\n%s\n\n" % arguments + + +def _make_pretty_examples(examples): + """ + Makes the examples description pretty and returns a formatted string if `examples` + starts with the example prefix. Otherwise, returns None. + + Expected input: + + Examples: + > SELECT ...; + ... + > SELECT ...; + ... + + Expected output: + **Examples:** + + ``` + > SELECT ...; + ... + > SELECT ...; + ... 
+ ``` + + """ + + if examples.startswith("\n Examples:"): + examples = "\n".join(map(lambda u: u[6:], examples.strip().split("\n")[1:])) + return "**Examples:**\n\n```\n%s\n```\n\n" % examples + + +def _make_pretty_note(note): + """ + Makes the note description pretty and returns a formatted string if `note` is not + an empty string. Otherwise, returns None. + + Expected input: + + ... + + Expected output: + **Note:** + + ... + + """ + + if note != "": + note = "\n".join(map(lambda n: n[4:], note.split("\n"))) + return "**Note:**\n%s\n" % note + + +def generate_sql_markdown(jvm, path): + """ + Generates a markdown file after listing the function information. The output file + is created in `path`. + + Expected output: + ### NAME + + USAGE + + **Arguments:** + + ARGUMENTS + + **Examples:** + + ``` + EXAMPLES + ``` + + **Note:** + + NOTE + + **Since:** SINCE + +
    + + """ + + with open(path, 'w') as mdfile: + for info in _list_function_infos(jvm): + name = info.name + usage = _make_pretty_usage(info.usage) + arguments = _make_pretty_arguments(info.arguments) + examples = _make_pretty_examples(info.examples) + note = _make_pretty_note(info.note) + since = info.since + + mdfile.write("### %s\n\n" % name) + if usage is not None: + mdfile.write("%s\n\n" % usage.strip()) + if arguments is not None: + mdfile.write(arguments) + if examples is not None: + mdfile.write(examples) + if note is not None: + mdfile.write(note) + if since is not None and since != "": + mdfile.write("**Since:** %s\n\n" % since.strip()) + mdfile.write("
    \n\n") + + +if __name__ == "__main__": + from pyspark.java_gateway import launch_gateway + + jvm = launch_gateway().jvm + markdown_file_path = "%s/docs/index.md" % os.path.dirname(sys.argv[0]) + generate_sql_markdown(jvm, markdown_file_path) diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index a5a8e2640586..3135a8a275da 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -63,6 +63,16 @@ ${hive.group} hive-beeline + + org.eclipse.jetty + jetty-server + provided + + + org.eclipse.jetty + jetty-servlet + provided + org.seleniumhq.selenium diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java index c3219aabfc23..f16863c1b41a 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java @@ -27,7 +27,7 @@ * */ public final class ServiceOperations { - private static final Log LOG = LogFactory.getLog(AbstractService.class); + private static final Log LOG = LogFactory.getLog(ServiceOperations.class); private ServiceOperations() { } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java index c1b3892f5206..859f9c8b449e 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java @@ -48,7 +48,7 @@ */ public class SessionManager extends CompositeService { - private static final Log LOG = LogFactory.getLog(CompositeService.class); + private static final Log LOG = LogFactory.getLog(SessionManager.class); public static final String HIVERCFILE = ".hiverc"; private HiveConf hiveConf; private final Map handleToSession = diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 5e4734ad3ad2..7442c987efc7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} @@ -45,7 +44,6 @@ import org.apache.spark.util.{ShutdownHookManager, Utils} * `HiveThriftServer2` thrift server. 
*/ object HiveThriftServer2 extends Logging { - var LOG = LogFactory.getLog(classOf[HiveServer2]) var uiTab: Option[ThriftServerTab] = None var listener: HiveThriftServer2Listener = _ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index ff3784cab9e2..f5191fa9132b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -71,9 +71,9 @@ private[hive] class SparkExecuteStatementOperation( def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. - sqlContext.sparkContext.clearJobGroup() logDebug(s"CLOSING $statementId") cleanup(OperationState.CLOSED) + sqlContext.sparkContext.clearJobGroup() } def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { @@ -253,6 +253,8 @@ private[hive] class SparkExecuteStatementOperation( return } else { setState(OperationState.ERROR) + HiveThriftServer2.listener.onStatementError( + statementId, e.getMessage, SparkUtils.exceptionString(e)) throw e } // Actually do need to catch Throwable as some failures don't inherit from Exception and @@ -271,9 +273,6 @@ private[hive] class SparkExecuteStatementOperation( override def cancel(): Unit = { logInfo(s"Cancel '$statement' with $statementId") - if (statementId != null) { - sqlContext.sparkContext.cancelJobGroup(statementId) - } cleanup(OperationState.CANCELED) } @@ -285,6 +284,9 @@ private[hive] class SparkExecuteStatementOperation( backgroundHandle.cancel(true) } } + if (statementId != null) { + sqlContext.sparkContext.cancelJobGroup(statementId) + } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 33e18a8da60f..832a15d09599 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -37,6 +37,8 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveUtils @@ -50,6 +52,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { private val prompt = "spark-sql" private val continuedPrompt = "".padTo(prompt.length, ' ') private var transport: TSocket = _ + private final val SPARK_HADOOP_PROP_PREFIX = "spark.hadoop." installSignalHandler() @@ -80,11 +83,17 @@ private[hive] object SparkSQLCLIDriver extends Logging { System.exit(1) } + val sparkConf = new SparkConf(loadDefaults = true) + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val extraConfigs = HiveUtils.formatTimeVarsForHiveClient(hadoopConf) + val cliConf = new HiveConf(classOf[SessionState]) - // Override the location of the metastore since this is only used for local execution. 
- HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false).foreach { - case (key, value) => cliConf.set(key, value) + (hadoopConf.iterator().asScala.map(kv => kv.getKey -> kv.getValue) + ++ sparkConf.getAll.toMap ++ extraConfigs).foreach { + case (k, v) => + cliConf.set(k, v) } + val sessionState = new CliSessionState(cliConf) sessionState.in = System.in @@ -134,6 +143,16 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Hive 1.2 + not supported in CLI throw new RuntimeException("Remote operations not supported") } + // Respect the configurations set by --hiveconf from the command line + // (based on Hive's CliDriver). + val hiveConfFromCmd = sessionState.getOverriddenConfigurations.entrySet().asScala + val newHiveConf = hiveConfFromCmd.map { kv => + // If the same property is configured by spark.hadoop.xxx, we ignore it and + // obey settings from spark properties + val k = kv.getKey + val v = sys.props.getOrElseUpdate(SPARK_HADOOP_PROP_PREFIX + k, kv.getValue) + (k, v) + } val cli = new SparkSQLCLIDriver cli.setHiveVariables(oproc.getHiveVariables) @@ -157,12 +176,8 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Execute -i init files (always in silent mode) cli.processInitFiles(sessionState) - // Respect the configurations set by --hiveconf from the command line - // (based on Hive's CliDriver). - val it = sessionState.getOverriddenConfigurations.entrySet().iterator() - while (it.hasNext) { - val kv = it.next() - SparkSQLEnv.sqlContext.setConf(kv.getKey, kv.getValue) + newHiveConf.foreach { kv => + SparkSQLEnv.sqlContext.setConf(kv._1, kv._2) } if (sessionState.execString != null) { @@ -272,7 +287,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val sessionState = SessionState.get().asInstanceOf[CliSessionState] - private val LOG = LogFactory.getLog("CliDriver") + private val LOG = LogFactory.getLog(classOf[SparkSQLCLIDriver]) private val console = new SessionState.LogHelper(LOG) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 1b17a9a56e5b..ad1f5eb9ca3a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.shims.Utils import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.{AbstractService, Service, ServiceException} @@ -47,6 +48,7 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC setSuperField(this, "sessionManager", sparkSqlSessionManager) addService(sparkSqlSessionManager) var sparkServiceUGI: UserGroupInformation = null + var httpUGI: UserGroupInformation = null if (UserGroupInformation.isSecurityEnabled) { try { @@ -57,6 +59,20 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC case e @ (_: IOException | _: LoginException) => throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) } + + // Try creating spnego UGI if it is configured. 
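The CLI change above applies a `--hiveconf` entry only when the same key was not already supplied as a `spark.hadoop.*` property, and `sys.props.getOrElseUpdate` is what encodes that precedence. A small sketch of the rule, with an illustrative key and value:

```scala
// If spark.hadoop.<key> is already set (e.g. via --conf), keep it and ignore the
// --hiveconf value; otherwise record the --hiveconf value under the prefixed key.
val key = "hive.metastore.warehouse.dir"          // illustrative key
val fromHiveconf = "/tmp/warehouse-from-hiveconf" // illustrative --hiveconf value
val effective = sys.props.getOrElseUpdate("spark.hadoop." + key, fromHiveconf)
// `effective` is the pre-existing spark.hadoop.* value when present, else fromHiveconf.
```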
+ val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_PRINCIPAL).trim + val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_KEYTAB).trim + if (principal.nonEmpty && keyTabFile.nonEmpty) { + try { + httpUGI = HiveAuthFactory.loginFromSpnegoKeytabAndReturnUGI(hiveConf) + setSuperField(this, "httpUGI", httpUGI) + } catch { + case e: IOException => + throw new ServiceException("Unable to login to spnego with given principal " + + s"$principal and keytab $keyTabFile: $e", e) + } + } } initCompositeService(hiveConf) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 0d5dc7af5f52..677590217344 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SQLContext} -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext) @@ -60,7 +60,9 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont try { context.sparkContext.setJobDescription(command) val execution = context.sessionState.executePlan(context.sql(command).logicalPlan) - hiveResponse = execution.hiveResultString() + hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) { + execution.hiveResultString() + } tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) } catch { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 2e0fa1ef77f8..f517bffccdf3 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -72,7 +72,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = listener.getExecutionList + val dataRows = listener.getExecutionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -103,7 +103,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } val content = -
    SQL Statistics
    ++ +
    SQL Statistics ({numStatement})
    ++
      {table.getOrElse("No statistics have been generated yet.")} @@ -142,7 +142,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val sessionList = listener.getSessionList val numBatches = sessionList.size val table = if (numBatches > 0) { - val dataRows = sessionList + val dataRows = sessionList.sortBy(_.startTimestamp).reverse val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { @@ -164,7 +164,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } val content = -
      Session Statistics
      ++ +
      Session Statistics ({numBatches})
      ++
        {table.getOrElse("No statistics have been generated yet.")} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index f39e9dcd3a5b..5cd2fdf6437c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -39,7 +39,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val parameterId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val content = @@ -65,7 +66,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) val timeSinceStart = System.currentTimeMillis() - startTime.getTime
        • - Started at: {startTime.toString} + Started at: {formatDate(startTime)}
        • Time since start: {formatDurationVerbose(timeSinceStart)} @@ -146,42 +147,6 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) {errorSummary}{details} } - /** Generate stats of batch sessions of the thrift server program */ - private def generateSessionStatsTable(): Seq[Node] = { - val sessionList = listener.getSessionList - val numBatches = sessionList.size - val table = if (numBatches > 0) { - val dataRows = - sessionList.sortBy(_.startTimestamp).reverse.map ( session => - Seq( - session.userName, - session.ip, - session.sessionId, - formatDate(session.startTimestamp), - formatDate(session.finishTimestamp), - formatDurationOption(Some(session.totalTime)), - session.totalExecution.toString - ) - ).toSeq - val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", - "Total Execute") - Some(listingTable(headerRow, dataRows)) - } else { - None - } - - val content = -
          Session Statistics
          ++ -
          -
            - {table.getOrElse("No statistics have been generated yet.")} -
          -
          - - content - } - - /** * Returns a human-readable string representing a duration such as "5 second 35 ms" */ @@ -197,4 +162,3 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) } } - diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala new file mode 100644 index 000000000000..3f135cc86498 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.util.Utils + +class HiveMetastoreLazyInitializationSuite extends SparkFunSuite { + + test("lazily initialize Hive client") { + val spark = SparkSession.builder() + .appName("HiveMetastoreLazyInitializationSuite") + .master("local[2]") + .enableHiveSupport() + .config("spark.hadoop.hive.metastore.uris", "thrift://127.0.0.1:11111") + .getOrCreate() + val originalLevel = org.apache.log4j.Logger.getRootLogger().getLevel + try { + // Avoid outputting a lot of expected warning logs + spark.sparkContext.setLogLevel("error") + + // We should be able to run Spark jobs without Hive client. + assert(spark.sparkContext.range(0, 1).count() === 1) + + // Make sure that we are not using the local derby metastore. 
+ val exceptionString = Utils.exceptionString(intercept[AnalysisException] { + spark.sql("show tables") + }) + for (msg <- Seq( + "show tables", + "Could not connect to meta store", + "org.apache.thrift.transport.TTransportException", + "Connection refused")) { + exceptionString.contains(msg) + } + } finally { + spark.sparkContext.setLogLevel(originalLevel.toString) + spark.stop() + } + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index d3cec11bd756..933fd7369380 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -283,4 +283,17 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "SET conf3;" -> "conftest" ) } + + test("SPARK-21451: spark.sql.warehouse.dir should respect options in --hiveconf") { + runCliWithin(1.minute)("set spark.sql.warehouse.dir;" -> warehousePath.getAbsolutePath) + } + + test("SPARK-21451: Apply spark.hadoop.* configurations") { + val tmpDir = Utils.createTempDir(namePrefix = "SPARK-21451") + runCliWithin( + 1.minute, + Seq(s"--conf", s"spark.hadoop.${ConfVars.METASTOREWAREHOUSE}=$tmpDir"))( + "set spark.sql.warehouse.dir;" -> tmpDir.getAbsolutePath) + tmpDir.delete() + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala new file mode 100644 index 000000000000..5f9ea4d26790 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.hadoop.hive.cli.CliSessionState +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.session.SessionState + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.hive.HiveUtils + +class HiveCliSessionStateSuite extends SparkFunSuite { + + def withSessionClear(f: () => Unit): Unit = { + try f finally SessionState.detachSession() + } + + test("CliSessionState will be reused") { + withSessionClear { () => + val hiveConf = new HiveConf(classOf[SessionState]) + HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false).foreach { + case (key, value) => hiveConf.set(key, value) + } + val sessionState: SessionState = new CliSessionState(hiveConf) + SessionState.start(sessionState) + val s1 = SessionState.get + val sparkConf = new SparkConf() + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val s2 = HiveUtils.newClientForMetadata(sparkConf, hadoopConf).getState + assert(s1 === s2) + assert(s2.isInstanceOf[CliSessionState]) + } + } + + test("SessionState will not be reused") { + withSessionClear { () => + val sparkConf = new SparkConf() + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false).foreach { + case (key, value) => hadoopConf.set(key, value) + } + val hiveClient = HiveUtils.newClientForMetadata(sparkConf, hadoopConf) + val s1 = hiveClient.getState + val s2 = hiveClient.newSession().getState + assert(s1 !== s2) + } + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index b6215bde6bf0..4997d7f96afa 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -33,11 +33,9 @@ import com.google.common.io.Files import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.apache.hive.service.auth.PlainSaslHelper -import org.apache.hive.service.cli.GetInfoType +import org.apache.hive.service.cli.{FetchOrientation, FetchType, GetInfoType} import org.apache.hive.service.cli.thrift.TCLIService.Client import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient -import org.apache.hive.service.cli.FetchOrientation -import org.apache.hive.service.cli.FetchType import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 0a53aaca404e..45791c69b4cb 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -39,7 +39,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalLocale = Locale.getDefault private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = 
TestHive.conf.inMemoryPartitionPruning - private val originalConvertMetastoreOrc = TestHive.conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone @@ -58,9 +57,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) - // Ensures that the plans generation use metastore relation and not OrcRelation - // Was done because SqlBuilder does not work with plans having logical relation - TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests @@ -76,7 +72,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 09dcc4055e00..66fad85ea026 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -57,6 +57,20 @@ spark-sql_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + ${project.version} + test + org.apache.spark spark-tags_${scala.binary.version} @@ -163,22 +177,17 @@ libfb303 - org.scalacheck - scalacheck_${scala.binary.version} - test + org.apache.derby + derby - org.apache.spark - spark-sql_${scala.binary.version} - test-jar - ${project.version} + org.scala-lang + scala-compiler test - org.apache.spark - spark-catalyst_${scala.binary.version} - test-jar - ${project.version} + org.scalacheck + scalacheck_${scala.binary.version} test diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ba48facff293..96dc983b0bfc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SourceOptions} import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -114,7 +114,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * should interpret these special data source properties and restore the original table metadata * before returning it. 
*/ - private def getRawTable(db: String, table: String): CatalogTable = withClient { + private[hive] def getRawTable(db: String, table: String): CatalogTable = withClient { client.getTable(db, table) } @@ -224,45 +224,43 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat throw new TableAlreadyExistsException(db = db, table = table) } - if (tableDefinition.tableType == VIEW) { - client.createTable(tableDefinition, ignoreIfExists) + // Ideally we should not create a managed table with location, but Hive serde table can + // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have + // to create the table directory and write out data before we create this table, to avoid + // exposing a partial written table. + val needDefaultTableLocation = tableDefinition.tableType == MANAGED && + tableDefinition.storage.locationUri.isEmpty + + val tableLocation = if (needDefaultTableLocation) { + Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) } else { - // Ideally we should not create a managed table with location, but Hive serde table can - // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have - // to create the table directory and write out data before we create this table, to avoid - // exposing a partial written table. - val needDefaultTableLocation = tableDefinition.tableType == MANAGED && - tableDefinition.storage.locationUri.isEmpty - - val tableLocation = if (needDefaultTableLocation) { - Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) - } else { - tableDefinition.storage.locationUri - } + tableDefinition.storage.locationUri + } - if (DDLUtils.isHiveTable(tableDefinition)) { - val tableWithDataSourceProps = tableDefinition.copy( - // We can't leave `locationUri` empty and count on Hive metastore to set a default table - // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default - // table location for tables in default database, while we expect to use the location of - // default database. - storage = tableDefinition.storage.copy(locationUri = tableLocation), - // Here we follow data source tables and put table metadata like table schema, partition - // columns etc. in table properties, so that we can work around the Hive metastore issue - // about not case preserving and make Hive serde table support mixed-case column names. - properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) - client.createTable(tableWithDataSourceProps, ignoreIfExists) - } else { - createDataSourceTable( - tableDefinition.withNewStorage(locationUri = tableLocation), - ignoreIfExists) - } + if (DDLUtils.isDatasourceTable(tableDefinition)) { + createDataSourceTable( + tableDefinition.withNewStorage(locationUri = tableLocation), + ignoreIfExists) + } else { + val tableWithDataSourceProps = tableDefinition.copy( + // We can't leave `locationUri` empty and count on Hive metastore to set a default table + // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default + // table location for tables in default database, while we expect to use the location of + // default database. + storage = tableDefinition.storage.copy(locationUri = tableLocation), + // Here we follow data source tables and put table metadata like table schema, partition + // columns etc. 
in table properties, so that we can work around the Hive metastore issue + // about not case preserving and make Hive serde table and view support mixed-case column + // names. + properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) + client.createTable(tableWithDataSourceProps, ignoreIfExists) } } private def createDataSourceTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = { // data source table always have a provider, it's guaranteed by `DDLUtils.isDatasourceTable`. val provider = table.provider.get + val options = new SourceOptions(table.storage.properties) // To work around some hive metastore issues, e.g. not case-preserving, bad decimal type // support, no column nullability, etc., we should do some extra works before saving table @@ -328,11 +326,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val qualifiedTableName = table.identifier.quotedString val maybeSerde = HiveSerDe.sourceToSerDe(provider) - val skipHiveMetadata = table.storage.properties - .getOrElse("skipHiveMetadata", "false").toBoolean val (hiveCompatibleTable, logMessage) = maybeSerde match { - case _ if skipHiveMetadata => + case _ if options.skipHiveMetadata => val message = s"Persisting data source table $qualifiedTableName into Hive metastore in" + "Spark SQL specific format, which is NOT compatible with Hive." @@ -389,15 +385,24 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * can be used as table properties later. */ private def tableMetaToTableProps(table: CatalogTable): mutable.Map[String, String] = { + tableMetaToTableProps(table, table.schema) + } + + private def tableMetaToTableProps( + table: CatalogTable, + schema: StructType): mutable.Map[String, String] = { val partitionColumns = table.partitionColumnNames val bucketSpec = table.bucketSpec val properties = new mutable.HashMap[String, String] + + properties.put(CREATED_SPARK_VERSION, table.createVersion) + // Serialized JSON schema string may be too long to be stored into a single metastore table // property. In this case, we split the JSON string and store each part as a separate table // property. val threshold = conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) - val schemaJsonString = table.schema.json + val schemaJsonString = schema.json // Split the JSON string. val parts = schemaJsonString.grouped(threshold).toSeq properties.put(DATASOURCE_SCHEMA_NUMPARTS, parts.size.toString) @@ -507,7 +512,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat identifier = TableIdentifier(newName, Some(db)), storage = storageWithNewPath) - client.alterTable(oldName, newTable) + client.alterTable(db, oldName, newTable) } private def getLocationFromStorageProps(table: CatalogTable): Option[String] = { @@ -527,7 +532,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat /** * Alter a table whose name that matches the one specified in `tableDefinition`, - * assuming the table exists. + * assuming the table exists. This method does not change the properties for data source and + * statistics. * * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. 
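The hunk above serializes the table schema to JSON and splits it across numbered table properties so each stored value stays under the metastore's length limit. Below is a minimal standalone sketch of that round trip using Spark's public StructType API; the property keys and the default threshold here are illustrative placeholders, not the exact DATASOURCE_* constants or SCHEMA_STRING_LENGTH_THRESHOLD value used by HiveExternalCatalog.

```scala
import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}

object SchemaPropsSketch {
  // Illustrative keys; the real catalog uses its own DATASOURCE_* constants.
  private val NumPartsKey = "example.schema.numParts"
  private def partKey(i: Int) = s"example.schema.part.$i"

  /** Split the JSON form of a schema into chunks that fit a metastore property value. */
  def toProps(schema: StructType, threshold: Int = 4000): Map[String, String] = {
    val parts = schema.json.grouped(threshold).toSeq
    val indexed = parts.zipWithIndex.map { case (part, i) => partKey(i) -> part }
    (indexed :+ (NumPartsKey -> parts.size.toString)).toMap
  }

  /** Reassemble the schema by concatenating the numbered parts in order. */
  def fromProps(props: Map[String, String]): StructType = {
    val numParts = props(NumPartsKey).toInt
    val json = (0 until numParts).map(i => props(partKey(i))).mkString
    DataType.fromJson(json).asInstanceOf[StructType]
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))
    val props = toProps(schema, threshold = 32) // small threshold to force several parts
    assert(fromProps(props) == schema)
  }
}
```

Splitting on write and concatenating on read keeps the schema lossless even when a single property value would otherwise be truncated.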
@@ -538,30 +544,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, tableDefinition.identifier.table) verifyTableProperties(tableDefinition) - // convert table statistics to properties so that we can persist them through hive api - val withStatsProps = if (tableDefinition.stats.isDefined) { - val stats = tableDefinition.stats.get - var statsProperties: Map[String, String] = - Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) - if (stats.rowCount.isDefined) { - statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() - } - val colNameTypeMap: Map[String, DataType] = - tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap - stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) - } - } - tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) - } else { - tableDefinition - } - if (tableDefinition.tableType == VIEW) { - client.alterTable(withStatsProps) + client.alterTable(tableDefinition) } else { - val oldTableDef = getRawTable(db, withStatsProps.identifier.table) + val oldTableDef = getRawTable(db, tableDefinition.identifier.table) val newStorage = if (DDLUtils.isHiveTable(tableDefinition)) { tableDefinition.storage @@ -611,12 +597,16 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat TABLE_PARTITION_PROVIDER -> TABLE_PARTITION_PROVIDER_FILESYSTEM } - // Sets the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition, - // to retain the spark specific format if it is. Also add old data source properties to table - // properties, to retain the data source table format. - val oldDataSourceProps = oldTableDef.properties.filter(_._1.startsWith(DATASOURCE_PREFIX)) - val newTableProps = oldDataSourceProps ++ withStatsProps.properties + partitionProviderProp - val newDef = withStatsProps.copy( + // Add old data source properties to table properties, to retain the data source table format. + // Add old stats properties to table properties, to retain spark's stats. + // Set the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition, + // to retain the spark specific format if it is. + val propsFromOldTable = oldTableDef.properties.filter { case (k, v) => + k.startsWith(DATASOURCE_PREFIX) || k.startsWith(STATISTICS_PREFIX) || + k.startsWith(CREATED_SPARK_VERSION) + } + val newTableProps = propsFromOldTable ++ tableDefinition.properties + partitionProviderProp + val newDef = tableDefinition.copy( storage = newStorage, schema = oldTableDef.schema, partitionColumnNames = oldTableDef.partitionColumnNames, @@ -630,29 +620,58 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def alterTableSchema(db: String, table: String, schema: StructType): Unit = withClient { requireTableExists(db, table) val rawTable = getRawTable(db, table) - val withNewSchema = rawTable.copy(schema = schema) - verifyColumnNames(withNewSchema) // Add table metadata such as table schema, partition columns, etc. to table properties. - val updatedTable = withNewSchema.copy( - properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema)) - try { - client.alterTable(updatedTable) - } catch { - case NonFatal(e) => - val warningMessage = - s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " + - "compatible way. 
Updating Hive metastore in Spark SQL specific format." - logWarning(warningMessage, e) - client.alterTable(updatedTable.copy(schema = updatedTable.partitionSchema)) + val updatedProperties = rawTable.properties ++ tableMetaToTableProps(rawTable, schema) + val withNewSchema = rawTable.copy(properties = updatedProperties, schema = schema) + verifyColumnNames(withNewSchema) + + if (isDatasourceTable(rawTable)) { + // For data source tables, first try to write it with the schema set; if that does not work, + // try again with updated properties and the partition schema. This is a simplified version of + // what createDataSourceTable() does, and may leave the table in a state unreadable by Hive + // (for example, the schema does not match the data source schema, or does not match the + // storage descriptor). + try { + client.alterTable(withNewSchema) + } catch { + case NonFatal(e) => + val warningMessage = + s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " + + "compatible way. Updating Hive metastore in Spark SQL specific format." + logWarning(warningMessage, e) + client.alterTable(withNewSchema.copy(schema = rawTable.partitionSchema)) + } + } else { + client.alterTable(withNewSchema) } } - override def getTable(db: String, table: String): CatalogTable = withClient { - restoreTableMetadata(getRawTable(db, table)) + override def alterTableStats( + db: String, + table: String, + stats: Option[CatalogStatistics]): Unit = withClient { + requireTableExists(db, table) + val rawTable = getRawTable(db, table) + + // For datasource tables and hive serde tables created by spark 2.1 or higher, + // the data schema is stored in the table properties. + val schema = restoreTableMetadata(rawTable).schema + + // convert table statistics to properties so that we can persist them through hive client + var statsProperties = + if (stats.isDefined) { + statsToProperties(stats.get, schema) + } else { + new mutable.HashMap[String, String]() + } + + val oldTableNonStatsProps = rawTable.properties.filterNot(_._1.startsWith(STATISTICS_PREFIX)) + val updatedTable = rawTable.copy(properties = oldTableNonStatsProps ++ statsProperties) + client.alterTable(updatedTable) } - override def getTableOption(db: String, table: String): Option[CatalogTable] = withClient { - client.getTableOption(db, table).map(restoreTableMetadata) + override def getTable(db: String, table: String): CatalogTable = withClient { + restoreTableMetadata(getRawTable(db, table)) } /** @@ -669,55 +688,55 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat var table = inputTable - if (table.tableType != VIEW) { - table.properties.get(DATASOURCE_PROVIDER) match { - // No provider in table properties, which means this is a Hive serde table. - case None => - table = restoreHiveSerdeTable(table) - - // This is a regular data source table. - case Some(provider) => - table = restoreDataSourceTable(table, provider) - } - } - - // construct Spark's statistics from information in Hive metastore - val statsProps = table.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + table.properties.get(DATASOURCE_PROVIDER) match { + case None if table.tableType == VIEW => + // If this is a view created by Spark 2.2 or higher versions, we should restore its schema + // from table properties. 
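alterTableStats above replaces only the statistics entries in the raw table properties and leaves every other property untouched. A plain-Scala sketch of that pattern follows, with a placeholder prefix and a toy Stats case class standing in for CatalogStatistics and the real STATISTICS_* keys.

```scala
object StatsPropsSketch {
  // Placeholder prefix; the catalog uses its own STATISTICS_* constants.
  private val StatsPrefix = "example.stats."

  /** A tiny stand-in for the catalog's statistics holder. */
  final case class Stats(sizeInBytes: BigInt, rowCount: Option[BigInt])

  private def statsToProps(stats: Stats): Map[String, String] = {
    val base = Map(s"${StatsPrefix}totalSize" -> stats.sizeInBytes.toString)
    base ++ stats.rowCount.map(rc => s"${StatsPrefix}numRows" -> rc.toString)
  }

  /** Replace any previously stored statistics while leaving other properties untouched. */
  def withNewStats(tableProps: Map[String, String], stats: Option[Stats]): Map[String, String] = {
    val nonStats = tableProps.filterNot { case (k, _) => k.startsWith(StatsPrefix) }
    nonStats ++ stats.map(statsToProps).getOrElse(Map.empty)
  }

  def main(args: Array[String]): Unit = {
    val old = Map("owner" -> "alice", s"${StatsPrefix}totalSize" -> "1")
    val updated = withNewStats(old, Some(Stats(sizeInBytes = 2048, rowCount = Some(100))))
    assert(updated("owner") == "alice")
    assert(updated(s"${StatsPrefix}totalSize") == "2048")
  }
}
```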
+ if (table.properties.contains(DATASOURCE_SCHEMA_NUMPARTS)) { + table = table.copy(schema = getSchemaFromTableProperties(table)) + } - if (statsProps.nonEmpty) { - val colStats = new mutable.HashMap[String, ColumnStat] + // No provider in table properties, which means this is a Hive serde table. + case None => + table = restoreHiveSerdeTable(table) - // For each column, recover its column stats. Note that this is currently a O(n^2) operation, - // but given the number of columns it usually not enormous, this is probably OK as a start. - // If we want to map this a linear operation, we'd need a stronger contract between the - // naming convention used for serialization. - table.schema.foreach { field => - if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { - // If "version" field is defined, then the column stat is defined. - val keyPrefix = columnStatKeyPropName(field.name, "") - val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => - (k.drop(keyPrefix.length), v) - } + // This is a regular data source table. + case Some(provider) => + table = restoreDataSourceTable(table, provider) + } - ColumnStat.fromMap(table.identifier.table, field, colStatMap).foreach { - colStat => colStats += field.name -> colStat - } - } - } + // Restore version info + val version: String = table.properties.getOrElse(CREATED_SPARK_VERSION, "2.2 or prior") - table = table.copy( - stats = Some(CatalogStatistics( - sizeInBytes = BigInt(table.properties(STATISTICS_TOTAL_SIZE)), - rowCount = table.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), - colStats = colStats.toMap))) + // Restore Spark's statistics from information in Metastore. + val restoredStats = + statsFromProperties(table.properties, table.identifier.table, table.schema) + if (restoredStats.isDefined) { + table = table.copy(stats = restoredStats) } // Get the original table properties as defined by the user. table.copy( + createVersion = version, properties = table.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }) } + // Reorder table schema to put partition columns at the end. Before Spark 2.2, the partition + // columns are not put at the end of schema. We need to reorder it when reading the schema + // from the table properties. + private def reorderSchema(schema: StructType, partColumnNames: Seq[String]): StructType = { + val partitionFields = partColumnNames.map { partCol => + schema.find(_.name == partCol).getOrElse { + throw new AnalysisException("The metadata is corrupted. Unable to find the " + + s"partition column names from the schema. schema: ${schema.catalogString}. " + + s"Partition columns: ${partColumnNames.mkString("[", ", ", "]")}") + } + } + StructType(schema.filterNot(partitionFields.contains) ++ partitionFields) + } + private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = { + val options = new SourceOptions(table.storage.properties) val hiveTable = table.copy( provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true) @@ -726,10 +745,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // schema from table properties. 
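reorderSchema, introduced above, moves partition columns to the end of the schema read back from table properties. The sketch below restates that logic in isolation against Spark's public StructType API, with a plain IllegalArgumentException standing in for AnalysisException.

```scala
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object ReorderSchemaSketch {
  /** Move the named partition columns to the end, preserving the order of the remaining fields. */
  def reorder(schema: StructType, partitionColumns: Seq[String]): StructType = {
    val partitionFields = partitionColumns.map { name =>
      schema.find(_.name == name).getOrElse(
        throw new IllegalArgumentException(
          s"Partition column $name not found in ${schema.catalogString}"))
    }
    StructType(schema.filterNot(partitionFields.contains) ++ partitionFields)
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("year", IntegerType), // partition column declared first
      StructField("id", IntegerType),
      StructField("name", StringType)))
    val reordered = reorder(schema, Seq("year"))
    assert(reordered.fieldNames.toSeq == Seq("id", "name", "year"))
  }
}
```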
if (table.properties.contains(DATASOURCE_SCHEMA_NUMPARTS)) { val schemaFromTableProps = getSchemaFromTableProperties(table) - if (DataType.equalsIgnoreCaseAndNullability(schemaFromTableProps, table.schema)) { + val partColumnNames = getPartitionColumnsFromTableProperties(table) + val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) + + if (DataType.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema) || + options.respectSparkSchema) { hiveTable.copy( - schema = schemaFromTableProps, - partitionColumnNames = getPartitionColumnsFromTableProperties(table), + schema = reorderedSchema, + partitionColumnNames = partColumnNames, bucketSpec = getBucketSpecFromTableProperties(table)) } else { // Hive metastore may change the table schema, e.g. schema inference. If the table @@ -759,11 +782,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) + val schemaFromTableProps = getSchemaFromTableProperties(table) + val partColumnNames = getPartitionColumnsFromTableProperties(table) + val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) + table.copy( provider = Some(provider), storage = storageWithLocation, - schema = getSchemaFromTableProperties(table), - partitionColumnNames = getPartitionColumnsFromTableProperties(table), + schema = reorderedSchema, + partitionColumnNames = partColumnNames, bucketSpec = getBucketSpecFromTableProperties(table), tracksPartitionsInCatalog = partitionProvider == Some(TABLE_PARTITION_PROVIDER_CATALOG)) } @@ -991,17 +1018,92 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat currentFullPath } + private def statsToProperties( + stats: CatalogStatistics, + schema: StructType): Map[String, String] = { + + var statsProperties: Map[String, String] = + Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) + if (stats.rowCount.isDefined) { + statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() + } + + val colNameTypeMap: Map[String, DataType] = + schema.fields.map(f => (f.name, f.dataType)).toMap + stats.colStats.foreach { case (colName, colStat) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } + } + + statsProperties + } + + private def statsFromProperties( + properties: Map[String, String], + table: String, + schema: StructType): Option[CatalogStatistics] = { + + val statsProps = properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + if (statsProps.isEmpty) { + None + } else { + + val colStats = new mutable.HashMap[String, ColumnStat] + + // For each column, recover its column stats. Note that this is currently a O(n^2) operation, + // but given the number of columns it usually not enormous, this is probably OK as a start. + // If we want to map this a linear operation, we'd need a stronger contract between the + // naming convention used for serialization. + schema.foreach { field => + if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { + // If "version" field is defined, then the column stat is defined. 
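statsFromProperties recovers per-column statistics by filtering the flat property map on a per-column key prefix and accepting only columns whose "version" entry is present. A simplified sketch of that recovery step is shown below; the key layout is a placeholder for what columnStatKeyPropName actually produces, and values are left as raw strings instead of ColumnStat objects.

```scala
object ColumnStatPropsSketch {
  // Placeholder layout: example.stats.colStats.<column>.<statField>
  private val Prefix = "example.stats.colStats."
  private def keyPrefix(column: String) = s"$Prefix$column."

  /** Collect, per column, the stat fields that were stored as flat string properties. */
  def columnStats(
      props: Map[String, String],
      columns: Seq[String]): Map[String, Map[String, String]] = {
    columns.flatMap { col =>
      val prefix = keyPrefix(col)
      val fields = props.collect {
        case (k, v) if k.startsWith(prefix) => k.drop(prefix.length) -> v
      }
      // Mirror the "version field present" convention: only treat the stats as defined if marked.
      if (fields.contains("version")) Some(col -> fields) else None
    }.toMap
  }

  def main(args: Array[String]): Unit = {
    val props = Map(
      s"${keyPrefix("id")}version" -> "1",
      s"${keyPrefix("id")}min" -> "0",
      s"${keyPrefix("id")}max" -> "42",
      s"${keyPrefix("name")}min" -> "a") // no version entry, so this column is ignored
    val stats = columnStats(props, Seq("id", "name"))
    assert(stats.keySet == Set("id"))
    assert(stats("id")("max") == "42")
  }
}
```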
+ val keyPrefix = columnStatKeyPropName(field.name, "") + val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => + (k.drop(keyPrefix.length), v) + } + + ColumnStat.fromMap(table, field, colStatMap).foreach { + colStat => colStats += field.name -> colStat + } + } + } + + Some(CatalogStatistics( + sizeInBytes = BigInt(statsProps(STATISTICS_TOTAL_SIZE)), + rowCount = statsProps.get(STATISTICS_NUM_ROWS).map(BigInt(_)), + colStats = colStats.toMap)) + } + } + override def alterPartitions( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withClient { val lowerCasedParts = newParts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) + + val rawTable = getRawTable(db, table) + + // For datasource tables and hive serde tables created by spark 2.1 or higher, + // the data schema is stored in the table properties. + val schema = restoreTableMetadata(rawTable).schema + + // convert partition statistics to properties so that we can persist them through hive api + val withStatsProps = lowerCasedParts.map(p => { + if (p.stats.isDefined) { + val statsProperties = statsToProperties(p.stats.get, schema) + p.copy(parameters = p.parameters ++ statsProperties) + } else { + p + } + }) + // Note: Before altering table partitions in Hive, you *must* set the current database // to the one that contains the table of interest. Otherwise you will end up with the // most helpful error message ever: "Unable to alter partition. alter is not possible." // See HIVE-2742 for more detail. client.setCurrentDatabase(db) - client.alterPartitions(db, table, lowerCasedParts) + client.alterPartitions(db, table, withStatsProps) } override def getPartition( @@ -1009,7 +1111,34 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, spec: TablePartitionSpec): CatalogTablePartition = withClient { val part = client.getPartition(db, table, lowerCasePartitionSpec(spec)) - part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) + restorePartitionMetadata(part, getTable(db, table)) + } + + /** + * Restores partition metadata from the partition properties. + * + * Reads partition-level statistics from partition properties, puts these + * into [[CatalogTablePartition#stats]] and removes these special entries + * from the partition properties. + */ + private def restorePartitionMetadata( + partition: CatalogTablePartition, + table: CatalogTable): CatalogTablePartition = { + val restoredSpec = restorePartitionSpec(partition.spec, table.partitionColumnNames) + + // Restore Spark's statistics from information in Metastore. + // Note: partition-level statistics were introduced in 2.3. 
+ val restoredStats = + statsFromProperties(partition.parameters, table.identifier.table, table.schema) + if (restoredStats.isDefined) { + partition.copy( + spec = restoredSpec, + stats = restoredStats, + parameters = partition.parameters.filterNot { + case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }) + } else { + partition.copy(spec = restoredSpec) + } } /** @@ -1020,7 +1149,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, spec: TablePartitionSpec): Option[CatalogTablePartition] = withClient { client.getPartitionOption(db, table, lowerCasePartitionSpec(spec)).map { part => - part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) + restorePartitionMetadata(part, getTable(db, table)) } } @@ -1051,9 +1180,19 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = withClient { val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table)) - client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part => + val res = client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) } + + partialSpec match { + // This might be a bug of Hive: When the partition value inside the partial partition spec + // contains dot, and we ask Hive to list partitions w.r.t. the partial partition spec, Hive + // treats dot as matching any single character and may return more partitions than we + // expected. Here we do an extra filter to drop unexpected partitions. + case Some(spec) if spec.exists(_._2.contains(".")) => + res.filter(p => isPartialPartitionSpec(spec, p.spec)) + case _ => res + } } override def listPartitionsByFilter( @@ -1095,6 +1234,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropFunction(db, name) } + override protected def doAlterFunction( + db: String, funcDefinition: CatalogFunction): Unit = withClient { + requireDbExists(db) + val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) + requireFunctionExists(db, functionName) + val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) + client.alterFunction(db, funcDefinition.copy(identifier = functionIdentifier)) + } + override protected def doRenameFunction( db: String, oldName: String, @@ -1147,6 +1295,8 @@ object HiveExternalCatalog { val TABLE_PARTITION_PROVIDER_CATALOG = "catalog" val TABLE_PARTITION_PROVIDER_FILESYSTEM = "filesystem" + val CREATED_SPARK_VERSION = SPARK_SQL_PREFIX + "create.version" + /** * Returns the fully qualified name used in table properties for a particular column stat. * For example, for column "mycol", and "min" stat, this should return @@ -1217,4 +1367,14 @@ object HiveExternalCatalog { getColumnNamesByType(metadata.properties, "sort", "sorting columns")) } } + + /** + * Detects a data source table. This checks both the table provider and the table properties, + * unlike DDLUtils which just checks the former. 
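The listPartitions change above works around Hive treating a dot in a partial partition value as a single-character wildcard by re-filtering the partitions the metastore returns. A standalone sketch of that post-filter follows, with a local isPartialSpecMatch helper standing in for the catalog's isPartialPartitionSpec.

```scala
object PartitionSpecFilterSketch {
  type TablePartitionSpec = Map[String, String]

  /** True if every column/value pair in the partial spec matches the full spec exactly. */
  def isPartialSpecMatch(partial: TablePartitionSpec, full: TablePartitionSpec): Boolean =
    partial.forall { case (col, value) => full.get(col).contains(value) }

  /** Drop partitions that were returned only because a dot was treated as a wildcard. */
  def filterListedPartitions(
      partialSpec: Option[TablePartitionSpec],
      listed: Seq[TablePartitionSpec]): Seq[TablePartitionSpec] = partialSpec match {
    case Some(spec) if spec.exists { case (_, v) => v.contains(".") } =>
      listed.filter(isPartialSpecMatch(spec, _))
    case _ => listed
  }

  def main(args: Array[String]): Unit = {
    val wanted = Map("version" -> "1.0")
    // "1x0" could come back if the dot is interpreted as "any single character".
    val listed = Seq(
      Map("version" -> "1.0", "region" -> "us"),
      Map("version" -> "1x0", "region" -> "eu"))
    assert(filterListedPartitions(Some(wanted), listed).map(_("region")) == Seq("us"))
  }
}
```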
+ */ + private[spark] def isDatasourceTable(table: CatalogTable): Boolean = { + val provider = table.provider.orElse(table.properties.get(DATASOURCE_PROVIDER)) + provider.isDefined && provider != Some(DDLUtils.HIVE_PROVIDER) + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6b98066cb76c..f0f2c493498b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.types._ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { // these are def_s and not val/lazy val since the latter would introduce circular references private def sessionState = sparkSession.sessionState - private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache + private def catalogProxy = sparkSession.sessionState.catalog import HiveMetastoreCatalog._ /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */ @@ -61,7 +61,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val key = QualifiedTableName( table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, table.table.toLowerCase) - tableRelationCache.getIfPresent(key) + catalogProxy.getCachedTable(key) } private def getCached( @@ -71,9 +71,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log expectedFileFormat: Class[_ <: FileFormat], partitionSchema: Option[StructType]): Option[LogicalRelation] = { - tableRelationCache.getIfPresent(tableIdentifier) match { + catalogProxy.getCachedTable(tableIdentifier) match { case null => None // Cache miss - case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) => + case logical @ LogicalRelation(relation: HadoopFsRelation, _, _, _) => val cachedRelationFileFormatClass = relation.fileFormat.getClass expectedFileFormat match { @@ -92,27 +92,27 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log Some(logical) } else { // If the cached relation is not updated, we invalidate it right away. - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } case _ => logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + s"However, we are getting a ${relation.fileFormat} from the metastore cache. " + "This cached entry will be invalidated.") - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } case other => logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + s"However, we are getting a $other from the metastore cache. 
" + "This cached entry will be invalidated.") - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } } def convertToLogicalRelation( - relation: CatalogRelation, + relation: HiveTableRelation, options: Map[String, String], fileFormatClass: Class[_ <: FileFormat], fileType: String): LogicalRelation = { @@ -154,7 +154,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log Some(partitionSchema)) val logicalRelation = cached.getOrElse { - val sizeInBytes = relation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong + val sizeInBytes = relation.stats.sizeInBytes.toLong val fileIndex = { val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes) if (lazyPruningEnabled) { @@ -171,12 +171,11 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log location = fileIndex, partitionSchema = partitionSchema, dataSchema = dataSchema, - // We don't support hive bucketed tables, only ones we write out. bucketSpec = None, fileFormat = fileFormat, options = options)(sparkSession = sparkSession) val created = LogicalRelation(fsRelation, updatedTable) - tableRelationCache.put(tableIdentifier, created) + catalogProxy.cacheTable(tableIdentifier, created) created } @@ -199,20 +198,19 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log sparkSession = sparkSession, paths = rootPath.toString :: Nil, userSpecifiedSchema = Option(dataSchema), - // We don't support hive bucketed tables, only ones we write out. bucketSpec = None, options = options, className = fileType).resolveRelation(), table = updatedTable) - tableRelationCache.put(tableIdentifier, created) + catalogProxy.cacheTable(tableIdentifier, created) created } logicalRelation }) } - // The inferred schema may have different filed names as the table schema, we should respect + // The inferred schema may have different field names as the table schema, we should respect // it, but also respect the exprId in table relation output. 
assert(result.output.length == relation.output.length && result.output.zip(relation.output).forall { case (a1, a2) => a1.dataType == a2.dataType }) @@ -223,7 +221,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } private def inferIfNeeded( - relation: CatalogRelation, + relation: HiveTableRelation, options: Map[String, String], fileFormat: FileFormat, fileIndexOpt: Option[FileIndex] = None): (StructType, CatalogTable) = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 377d4f2473c5..b256ffc27b19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -30,14 +30,12 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, DoubleType} -import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( @@ -58,55 +56,52 @@ private[sql] class HiveSessionCatalog( parser, functionResourceLoader) { - override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = { - makeFunctionBuilder(funcName, Utils.classForName(className)) - } - /** - * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + * Constructs a [[Expression]] based on the provided class that represents a function. + * + * This performs reflection to decide what type of [[Expression]] to return in the builder. */ - private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = { - // When we instantiate hive UDF wrapper class, we may throw exception if the input - // expressions don't satisfy the hive UDF, such as type mismatch, input number - // mismatch, etc. Here we catch the exception and throw AnalysisException instead. - (children: Seq[Expression]) => { + override def makeFunctionExpression( + name: String, + clazz: Class[_], + input: Seq[Expression]): Expression = { + + Try(super.makeFunctionExpression(name, clazz, input)).getOrElse { + var udfExpr: Option[Expression] = None try { + // When we instantiate hive UDF wrapper class, we may throw exception if the input + // expressions don't satisfy the hive UDF, such as type mismatch, input number + // mismatch, etc. Here we catch the exception and throw AnalysisException instead. if (classOf[UDF].isAssignableFrom(clazz)) { - val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children) - udf.dataType // Force it to check input data types. - udf + udfExpr = Some(HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.dataType // Force it to check input data types. 
} else if (classOf[GenericUDF].isAssignableFrom(clazz)) { - val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children) - udf.dataType // Force it to check input data types. - udf + udfExpr = Some(HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.dataType // Force it to check input data types. } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { - val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children) - udaf.dataType // Force it to check input data types. - udaf + udfExpr = Some(HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.dataType // Force it to check input data types. } else if (classOf[UDAF].isAssignableFrom(clazz)) { - val udaf = HiveUDAFFunction( + udfExpr = Some(HiveUDAFFunction( name, new HiveFunctionWrapper(clazz.getName), - children, - isUDAFBridgeRequired = true) - udaf.dataType // Force it to check input data types. - udaf + input, + isUDAFBridgeRequired = true)) + udfExpr.get.dataType // Force it to check input data types. } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { - val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) - udtf.elementSchema // Force it to check input data types. - udtf - } else { - throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'") + udfExpr = Some(HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.asInstanceOf[HiveGenericUDTF].elementSchema // Force it to check data types. } } catch { - case ae: AnalysisException => - throw ae case NonFatal(e) => val analysisException = - new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}': $e") + new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e") analysisException.setStackTrace(e.getStackTrace) throw analysisException } + udfExpr.getOrElse { + throw new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'") + } } } @@ -129,7 +124,7 @@ private[sql] class HiveSessionCatalog( Try(super.lookupFunction(funcName, children)) match { case Success(expr) => expr case Failure(error) => - if (functionRegistry.functionExists(funcName.unquotedString)) { + if (functionRegistry.functionExists(funcName)) { // If the function actually exists in functionRegistry, it means that there is an // error when we create the Expression using the given children. // We need to throw the original exception. @@ -140,7 +135,7 @@ private[sql] class HiveSessionCatalog( // Hive is case insensitive. val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT) if (!hiveFunctions.contains(functionName)) { - failFunctionLookup(funcName.unquotedString) + failFunctionLookup(funcName) } // TODO: Remove this fallback path once we implement the list of fallback functions @@ -148,12 +143,12 @@ private[sql] class HiveSessionCatalog( val functionInfo = { try { Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse( - failFunctionLookup(funcName.unquotedString)) + failFunctionLookup(funcName)) } catch { // If HiveFunctionRegistry.getFunctionInfo throws an exception, // we are failing to load a Hive builtin function, which means that // the given function is not a Hive builtin function. 
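makeFunctionExpression above decides which Hive wrapper to build by checking, via isAssignableFrom, which Hive interface the registered class implements, and it forces dataType eagerly so a bad registration fails at analysis time rather than at execution. The sketch below shows only that dispatch-and-validate shape; the marker traits, wrapper case classes, and arity check are invented stand-ins for Hive's UDF/GenericUDF classes and the real dataType check, and a plain exception replaces AnalysisException.

```scala
object UdfDispatchSketch {
  // Made-up marker interfaces standing in for Hive's UDF / GenericUDF / UDAF classes.
  trait SimpleUdf { def call(args: Seq[Any]): Any }
  trait GenericUdf { def evaluate(args: Seq[Any]): Any }

  sealed trait WrappedFunction { def validate(): Unit }
  final case class WrappedSimple(name: String, arity: Int) extends WrappedFunction {
    // Stand-in for forcing `dataType`: fail fast on an obviously bad registration.
    def validate(): Unit = require(arity >= 0, s"$name: negative arity")
  }
  final case class WrappedGeneric(name: String, arity: Int) extends WrappedFunction {
    def validate(): Unit = require(arity >= 1, s"$name: generic UDFs need at least one argument")
  }

  /** Pick a wrapper by inspecting which interface the class implements, then validate eagerly. */
  def makeFunctionExpression(name: String, clazz: Class[_], arity: Int): WrappedFunction = {
    val wrapped: Option[WrappedFunction] =
      if (classOf[SimpleUdf].isAssignableFrom(clazz)) Some(WrappedSimple(name, arity))
      else if (classOf[GenericUdf].isAssignableFrom(clazz)) Some(WrappedGeneric(name, arity))
      else None

    try {
      // Force the check now so a bad registration surfaces at resolution time.
      wrapped.foreach(_.validate())
    } catch {
      case e: IllegalArgumentException =>
        throw new IllegalArgumentException(
          s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e", e)
    }
    wrapped.getOrElse(
      throw new IllegalArgumentException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'"))
  }

  // A trivial function class to exercise the dispatch.
  class Upper extends SimpleUdf { def call(args: Seq[Any]): Any = args.head.toString.toUpperCase }

  def main(args: Array[String]): Unit = {
    assert(makeFunctionExpression("upper", classOf[Upper], arity = 1) == WrappedSimple("upper", 1))
    try {
      makeFunctionExpression("bad", classOf[GenericUdf], arity = 0)
      sys.error("should have failed")
    } catch { case _: IllegalArgumentException => () }
  }
}
```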
- case NonFatal(e) => failFunctionLookup(funcName.unquotedString) + case NonFatal(e) => failFunctionLookup(funcName) } } val className = functionInfo.getFunctionClass.getName @@ -161,9 +156,9 @@ private[sql] class HiveSessionCatalog( FunctionIdentifier(functionName.toLowerCase(Locale.ROOT), database) val func = CatalogFunction(functionIdentifier, className, Nil) // Put this Hive built-in function to our function registry. - registerFunction(func, ignoreIfExists = false) + registerFunction(func, overrideIfExists = false) // Now, we need to create the Expression. - functionRegistry.lookupFunction(functionName, children) + functionRegistry.lookupFunction(functionIdentifier, children) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index e16c9e46b772..92cb4ef11c9e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -69,22 +69,23 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override protected def analyzer: Analyzer = new Analyzer(catalog, conf) { override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = new ResolveHiveSerdeTable(session) +: - new FindDataSourceTable(session) +: - new ResolveSQLOnFile(session) +: - customResolutionRules + new FindDataSourceTable(session) +: + new ResolveSQLOnFile(session) +: + customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = new DetermineTableStats(session) +: - RelationConversions(conf, catalog) +: - PreprocessTableCreation(session) +: - PreprocessTableInsertion(conf) +: - DataSourceAnalysis(conf) +: - HiveAnalysis +: - customPostHocResolutionRules + RelationConversions(conf, catalog) +: + PreprocessTableCreation(session) +: + PreprocessTableInsertion(conf) +: + DataSourceAnalysis(conf) +: + HiveAnalysis +: + customPostHocResolutionRules override val extendedCheckRules: Seq[LogicalPlan => Unit] = PreWriteCheck +: - customCheckRules + PreReadCheck +: + customCheckRules } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 09a5eda6e543..805b3171cdaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -21,13 +21,13 @@ import java.io.IOException import java.util.Locale import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogStorageFormat, CatalogTable} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, ScriptTransformation} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan, + ScriptTransformation} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} @@ -116,23 +116,10 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { class DetermineTableStats(session: SparkSession) extends 
Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case relation: CatalogRelation + case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta - // TODO: check if this estimate is valid for tables after partition pruning. - // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be - // relatively cheap if parameters for the table are populated into the metastore. - // Besides `totalSize`, there are also `numFiles`, `numRows`, `rawDataSize` keys - // (see StatsSetupConst in Hive) that we can look at in the future. - // When table is external,`totalSize` is always zero, which will influence join strategy - // so when `totalSize` is zero, use `rawDataSize` instead. - val totalSize = table.properties.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) - val rawDataSize = table.properties.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong) - val sizeInBytes = if (totalSize.isDefined && totalSize.get > 0) { - totalSize.get - } else if (rawDataSize.isDefined && rawDataSize.get > 0) { - rawDataSize.get - } else if (session.sessionState.conf.fallBackToHdfsForStatsEnabled) { + val sizeInBytes = if (session.sessionState.conf.fallBackToHdfsForStatsEnabled) { try { val hadoopConf = session.sessionState.newHadoopConf() val tablePath = new Path(table.location) @@ -160,15 +147,24 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { */ object HiveAnalysis extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case InsertIntoTable(relation: CatalogRelation, partSpec, query, overwrite, ifNotExists) - if DDLUtils.isHiveTable(relation.tableMeta) => - InsertIntoHiveTable(relation.tableMeta, partSpec, query, overwrite, ifNotExists) + case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) + if DDLUtils.isHiveTable(r.tableMeta) => + InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => + DDLUtils.checkDataSchemaFieldNames(tableDesc) CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => + DDLUtils.checkDataSchemaFieldNames(tableDesc) CreateHiveTableAsSelectCommand(tableDesc, query, mode) + + case InsertIntoDir(isLocal, storage, provider, child, overwrite) + if DDLUtils.isHiveTable(provider) => + val outputPath = new Path(storage.locationUri.get) + if (overwrite) DDLUtils.verifyNotReadPath(child, outputPath) + + InsertIntoHiveDirCommand(isLocal, storage, child, overwrite) } } @@ -184,13 +180,13 @@ object HiveAnalysis extends Rule[LogicalPlan] { case class RelationConversions( conf: SQLConf, sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] { - private def isConvertible(relation: CatalogRelation): Boolean = { + private def isConvertible(relation: HiveTableRelation): Boolean = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } - private def convert(relation: CatalogRelation): LogicalRelation = { + private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = 
relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) if (serde.contains("parquet")) { val options = Map(ParquetOptions.MERGE_SCHEMA -> @@ -207,14 +203,14 @@ case class RelationConversions( override def apply(plan: LogicalPlan): LogicalPlan = { plan transformUp { // Write path - case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) + case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists) // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). - if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && - !r.isPartitioned && isConvertible(r) => - InsertIntoTable(convert(r), partition, query, overwrite, ifNotExists) + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && + !r.isPartitioned && isConvertible(r) => + InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists) // Read path - case relation: CatalogRelation + case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) => convert(relation) } @@ -242,7 +238,7 @@ private[hive] trait HiveStrategies { */ object HiveTableScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, predicates, relation: CatalogRelation) => + case PhysicalOperation(projectList, predicates, relation: HiveTableRelation) => // Filter out all predicates that only deal with partition keys, these are given to the // hive table scan operator to be used for partition pruning. val partitionKeyIds = AttributeSet(relation.partitionCols) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 3de60c7fc131..80b9a3dc9605 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -24,18 +24,20 @@ import java.sql.Timestamp import java.util.Locale import java.util.concurrent.TimeUnit -import scala.collection.mutable.HashMap import scala.collection.JavaConverters._ +import scala.collection.mutable.HashMap import scala.language.implicitConversions import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.hadoop.util.VersionInfo import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.CatalogTable @@ -86,8 +88,8 @@ private[spark] object HiveUtils extends Logging { .createWithDefault("builtin") val CONVERT_METASTORE_PARQUET = buildConf("spark.sql.hive.convertMetastoreParquet") - .doc("When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + - "the built in support.") + .doc("When set to true, the built-in Parquet reader and writer are used to process " + + "parquet tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf .createWithDefault(true) @@ -101,8 +103,8 @@ private[spark] object HiveUtils extends Logging { val CONVERT_METASTORE_ORC = buildConf("spark.sql.hive.convertMetastoreOrc") .internal() - .doc("When set to false, Spark SQL will use the Hive SerDe 
for ORC tables instead of " + - "the built in support.") + .doc("When set to true, the built-in ORC reader and writer are used to process " + + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf .createWithDefault(false) @@ -174,9 +176,9 @@ private[spark] object HiveUtils extends Logging { } /** - * Configurations needed to create a [[HiveClient]]. + * Change time configurations needed to create a [[HiveClient]] into unified [[Long]] format. */ - private[hive] def hiveClientConfigurations(hadoopConf: Configuration): Map[String, String] = { + private[hive] def formatTimeVarsForHiveClient(hadoopConf: Configuration): Map[String, String] = { // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- // compatibility when users are trying to connecting to a Hive metastore of lower version, @@ -229,6 +231,22 @@ private[spark] object HiveUtils extends Logging { }.toMap } + /** + * Check current Thread's SessionState type + * @return true when SessionState.get returns an instance of CliSessionState, + * false when it gets non-CliSessionState instance or null + */ + def isCliSessionState(): Boolean = { + val state = SessionState.get + var temp: Class[_] = if (state != null) state.getClass else null + var found = false + while (temp != null && !found) { + found = temp.getName == "org.apache.hadoop.hive.cli.CliSessionState" + temp = temp.getSuperclass + } + found + } + /** * Create a [[HiveClient]] used for execution. * @@ -245,7 +263,7 @@ private[spark] object HiveUtils extends Logging { val loader = new IsolatedClientLoader( version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), sparkConf = conf, - execJars = Seq(), + execJars = Seq.empty, hadoopConf = hadoopConf, config = newTemporaryConfiguration(useInMemoryDerby = true), isolationOn = false, @@ -262,7 +280,7 @@ private[spark] object HiveUtils extends Logging { protected[hive] def newClientForMetadata( conf: SparkConf, hadoopConf: Configuration): HiveClient = { - val configurations = hiveClientConfigurations(hadoopConf) + val configurations = formatTimeVarsForHiveClient(hadoopConf) newClientForMetadata(conf, hadoopConf, configurations) } @@ -312,7 +330,7 @@ private[spark] object HiveUtils extends Logging { hadoopConf = hadoopConf, execJars = jars.toSeq, config = configurations, - isolationOn = true, + isolationOn = !isCliSessionState(), barrierPrefixes = hiveMetastoreBarrierPrefixes, sharedPrefixes = hiveMetastoreSharedPrefixes) } else if (hiveMetastoreJars == "maven") { @@ -404,6 +422,13 @@ private[spark] object HiveUtils extends Logging { propMap.put(ConfVars.METASTORE_EVENT_LISTENERS.varname, "") propMap.put(ConfVars.METASTORE_END_FUNCTION_LISTENERS.varname, "") + // SPARK-21451: Spark will gather all `spark.hadoop.*` properties from a `SparkConf` to a + // Hadoop Configuration internally, as long as it happens after SparkContext initialized. + // Some instances such as `CliSessionState` used in `SparkSQLCliDriver` may also rely on these + // Configuration. But it happens before SparkContext initialized, we need to take them from + // system properties in the form of regular hadoop configurations. 
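The SPARK-21451 comment above describes copying `spark.hadoop.*` entries from system properties into the Hive configuration map. The helper used for this is SparkHadoopUtil.appendSparkHadoopConfigs; the sketch below only illustrates the presumed idea (strip the prefix and copy the remainder) under that assumption, and is not that helper's actual implementation.

```scala
import scala.collection.mutable

object SparkHadoopPropsSketch {
  private val Prefix = "spark.hadoop."

  /** Copy `spark.hadoop.*` entries into a plain Hadoop-style property map, dropping the prefix. */
  def appendSparkHadoopConfigs(
      source: Map[String, String],
      target: mutable.Map[String, String]): Unit =
    source.foreach {
      case (key, value) if key.startsWith(Prefix) => target.put(key.stripPrefix(Prefix), value)
      case _ => // not a spark.hadoop.* property, ignore it
    }

  def main(args: Array[String]): Unit = {
    val target = mutable.Map.empty[String, String]
    appendSparkHadoopConfigs(
      Map("spark.hadoop.fs.defaultFS" -> "file:///", "spark.app.name" -> "demo"), target)
    assert(target == mutable.Map("fs.defaultFS" -> "file:///"))
  }
}
```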
+ SparkHadoopUtil.get.appendSparkHadoopConfigs(sys.props.toMap, propMap) + propMap.toMap } @@ -414,7 +439,7 @@ private[spark] object HiveUtils extends Logging { protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => struct.toSeq.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" + case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") @@ -437,7 +462,7 @@ private[spark] object HiveUtils extends Logging { protected def toHiveStructString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => struct.toSeq.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" + case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 16c1103dd1ea..cc8907a0bbc9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -39,8 +39,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -65,7 +67,7 @@ class HadoopTableReader( @transient private val tableDesc: TableDesc, @transient private val sparkSession: SparkSession, hadoopConf: Configuration) - extends TableReader with Logging { + extends TableReader with CastSupport with Logging { // Hadoop honors "mapreduce.job.maps" as hint, // but will ignore when mapreduce.jobtracker.address is "local". 
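The toHiveString/toHiveStructString hunks above only make the argument tupling explicit, but the struct case itself is a small formatter: quote each field name, render its value, and join the pairs inside braces. A plain-Scala sketch of just that case, ignoring the nested array and map branches the real method also covers:

```scala
object HiveStringSketch {
  /** Render a struct-like sequence of (fieldName, value) pairs as a Hive-style {...} string. */
  def structToHiveString(fields: Seq[(String, Any)]): String =
    fields.map { case (name, value) =>
      val rendered = value match {
        case null => "null"
        case s: String => "\"" + s + "\""
        case other => other.toString
      }
      s""""$name":$rendered"""
    }.mkString("{", ",", "}")

  def main(args: Array[String]): Unit = {
    assert(structToHiveString(Seq("id" -> 1, "name" -> "spark")) == """{"id":1,"name":"spark"}""")
  }
}
```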
@@ -86,6 +88,8 @@ class HadoopTableReader( private val _broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + override def conf: SQLConf = sparkSession.sessionState.conf + override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, @@ -162,8 +166,8 @@ class HadoopTableReader( if (!sparkSession.sessionState.conf.verifyPartitionPath) { partitionToDeserializer } else { - var existPathSet = collection.mutable.Set[String]() - var pathPatternSet = collection.mutable.Set[String]() + val existPathSet = collection.mutable.Set[String]() + val pathPatternSet = collection.mutable.Set[String]() partitionToDeserializer.filter { case (partition, partDeserializer) => def updateExistPathSetByPathPattern(pathPatternStr: String) { @@ -181,8 +185,8 @@ class HadoopTableReader( } val partPath = partition.getDataLocation - val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); - var pathPatternStr = getPathPatternByPath(partNum, partPath) + val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size() + val pathPatternStr = getPathPatternByPath(partNum, partPath) if (!pathPatternSet.contains(pathPatternStr)) { pathPatternSet += pathPatternStr updateExistPathSetByPathPattern(pathPatternStr) @@ -227,7 +231,7 @@ class HadoopTableReader( def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => val partOrdinal = partitionKeys.indexOf(attr) - row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) + row(ordinal) = cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 16a80f9fff45..ee3eb2ee8abe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -38,6 +38,12 @@ private[hive] trait HiveClient { /** Returns the configuration for the given key in the current session. */ def getConf(key: String, defaultValue: String): String + /** + * Return the associated Hive SessionState of this [[HiveClientImpl]] + * @return [[Any]] not SessionState to avoid linkage error + */ + def getState: Any + /** * Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will * result in one string. @@ -84,10 +90,15 @@ private[hive] trait HiveClient { def dropTable(dbName: String, tableName: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit /** Alter a table whose name matches the one specified in `table`, assuming it exists. */ - final def alterTable(table: CatalogTable): Unit = alterTable(table.identifier.table, table) + final def alterTable(table: CatalogTable): Unit = { + alterTable(table.database, table.identifier.table, table) + } - /** Updates the given table with new metadata, optionally renaming the table. */ - def alterTable(tableName: String, table: CatalogTable): Unit + /** + * Updates the given table with new metadata, optionally renaming the table or + * moving across different database. + */ + def alterTable(dbName: String, tableName: String, table: CatalogTable): Unit /** Creates a new database with the given name. 
*/ def createDatabase(database: CatalogDatabase, ignoreIfExists: Boolean): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 387ec4f96723..c4e48c9360db 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -21,22 +21,25 @@ import java.io.{File, PrintStream} import java.util.Locale import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.language.reflectiveCalls import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} -import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema} +import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Order} import org.apache.hadoop.hive.metastore.api.{SerDeInfo, StorageDescriptor} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer.HIVE_COLUMN_ORDER_ASC import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.AnalysisException @@ -48,6 +51,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.client.HiveClientImpl._ import org.apache.spark.sql.types._ import org.apache.spark.util.{CircularBuffer, Utils} @@ -104,106 +108,87 @@ private[hive] class HiveClientImpl( // Create an internal session state for this HiveClientImpl. val state: SessionState = { val original = Thread.currentThread().getContextClassLoader - // Switch to the initClassLoader. - Thread.currentThread().setContextClassLoader(initClassLoader) - - // Set up kerberos credentials for UserGroupInformation.loginUser within - // current class loader - if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { - val principalName = sparkConf.get("spark.yarn.principal") - val keytabFileName = sparkConf.get("spark.yarn.keytab") - if (!new File(keytabFileName).exists()) { - throw new SparkException(s"Keytab file: ${keytabFileName}" + - " specified in spark.yarn.keytab does not exist") - } else { - logInfo("Attempting to login to Kerberos" + - s" using principal: ${principalName} and keytab: ${keytabFileName}") - UserGroupInformation.loginUserFromKeytab(principalName, keytabFileName) + if (clientLoader.isolationOn) { + // Switch to the initClassLoader. 
+ Thread.currentThread().setContextClassLoader(initClassLoader) + // Set up kerberos credentials for UserGroupInformation.loginUser within current class loader + if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { + val principal = sparkConf.get("spark.yarn.principal") + val keytab = sparkConf.get("spark.yarn.keytab") + SparkHadoopUtil.get.loginUserFromKeytab(principal, keytab) } - } - - def isCliSessionState(state: SessionState): Boolean = { - var temp: Class[_] = if (state != null) state.getClass else null - var found = false - while (temp != null && !found) { - found = temp.getName == "org.apache.hadoop.hive.cli.CliSessionState" - temp = temp.getSuperclass + try { + newState() + } finally { + Thread.currentThread().setContextClassLoader(original) } - found - } - - val ret = try { - // originState will be created if not exists, will never be null - val originalState = SessionState.get() - if (isCliSessionState(originalState)) { - // In `SparkSQLCLIDriver`, we have already started a `CliSessionState`, - // which contains information like configurations from command line. Later - // we call `SparkSQLEnv.init()` there, which would run into this part again. - // so we should keep `conf` and reuse the existing instance of `CliSessionState`. - originalState - } else { - val hiveConf = new HiveConf(classOf[SessionState]) - // 1: we set all confs in the hadoopConf to this hiveConf. - // This hadoopConf contains user settings in Hadoop's core-site.xml file - // and Hive's hive-site.xml file. Note, we load hive-site.xml file manually in - // SharedState and put settings in this hadoopConf instead of relying on HiveConf - // to load user settings. Otherwise, HiveConf's initialize method will override - // settings in the hadoopConf. This issue only shows up when spark.sql.hive.metastore.jars - // is not set to builtin. When spark.sql.hive.metastore.jars is builtin, the classpath - // has hive-site.xml. So, HiveConf will use that to override its default values. - hadoopConf.iterator().asScala.foreach { entry => - val key = entry.getKey - val value = entry.getValue - if (key.toLowerCase(Locale.ROOT).contains("password")) { - logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=xxx") - } else { - logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=$value") - } - hiveConf.set(key, value) - } - // HiveConf is a Hadoop Configuration, which has a field of classLoader and - // the initial value will be the current thread's context class loader - // (i.e. initClassLoader at here). - // We call initialConf.setClassLoader(initClassLoader) at here to make - // this action explicit. - hiveConf.setClassLoader(initClassLoader) - // 2: we set all spark confs to this hiveConf. - sparkConf.getAll.foreach { case (k, v) => - if (k.toLowerCase(Locale.ROOT).contains("password")) { - logDebug(s"Applying Spark config to Hive Conf: $k=xxx") - } else { - logDebug(s"Applying Spark config to Hive Conf: $k=$v") - } - hiveConf.set(k, v) - } - // 3: we set all entries in config to this hiveConf. - extraConfig.foreach { case (k, v) => - if (k.toLowerCase(Locale.ROOT).contains("password")) { - logDebug(s"Applying extra config to HiveConf: $k=xxx") - } else { - logDebug(s"Applying extra config to HiveConf: $k=$v") - } - hiveConf.set(k, v) - } - val state = new SessionState(hiveConf) - if (clientLoader.cachedHive != null) { - Hive.set(clientLoader.cachedHive.asInstanceOf[Hive]) + } else { + // Isolation off means we detect a CliSessionState instance in current thread. 
+ // 1: Inside the spark project, we have already started a CliSessionState in + // `SparkSQLCLIDriver`, which contains configurations from command lines. Later, we call + // `SparkSQLEnv.init()` there, which would new a hive client again. so we should keep those + // configurations and reuse the existing instance of `CliSessionState`. In this case, + // SessionState.get will always return a CliSessionState. + // 2: In another case, a user app may start a CliSessionState outside spark project with built + // in hive jars, which will turn off isolation, if SessionSate.detachSession is + // called to remove the current state after that, hive client created later will initialize + // its own state by newState() + val ret = SessionState.get + if (ret != null) { + // hive.metastore.warehouse.dir is determined in SharedState after the CliSessionState + // instance constructed, we need to follow that change here. + Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)).foreach { dir => + ret.getConf.setVar(ConfVars.METASTOREWAREHOUSE, dir) } - SessionState.start(state) - state.out = new PrintStream(outputBuffer, true, "UTF-8") - state.err = new PrintStream(outputBuffer, true, "UTF-8") - state + ret + } else { + newState() } - } finally { - Thread.currentThread().setContextClassLoader(original) } - ret } // Log the default warehouse location. logInfo( s"Warehouse location for Hive client " + - s"(version ${version.fullVersion}) is ${conf.get("hive.metastore.warehouse.dir")}") + s"(version ${version.fullVersion}) is ${conf.getVar(ConfVars.METASTOREWAREHOUSE)}") + + private def newState(): SessionState = { + val hiveConf = new HiveConf(classOf[SessionState]) + // HiveConf is a Hadoop Configuration, which has a field of classLoader and + // the initial value will be the current thread's context class loader + // (i.e. initClassLoader at here). + // We call initialConf.setClassLoader(initClassLoader) at here to make + // this action explicit. + hiveConf.setClassLoader(initClassLoader) + + // 1: Take all from the hadoopConf to this hiveConf. + // This hadoopConf contains user settings in Hadoop's core-site.xml file + // and Hive's hive-site.xml file. Note, we load hive-site.xml file manually in + // SharedState and put settings in this hadoopConf instead of relying on HiveConf + // to load user settings. Otherwise, HiveConf's initialize method will override + // settings in the hadoopConf. This issue only shows up when spark.sql.hive.metastore.jars + // is not set to builtin. When spark.sql.hive.metastore.jars is builtin, the classpath + // has hive-site.xml. So, HiveConf will use that to override its default values. + // 2: we set all spark confs to this hiveConf. + // 3: we set all entries in config to this hiveConf. + (hadoopConf.iterator().asScala.map(kv => kv.getKey -> kv.getValue) + ++ sparkConf.getAll.toMap ++ extraConfig).foreach { case (k, v) => + logDebug( + s""" + |Applying Hadoop/Hive/Spark and extra properties to Hive Conf: + |$k=${if (k.toLowerCase(Locale.ROOT).contains("password")) "xxx" else v} + """.stripMargin) + hiveConf.set(k, v) + } + val state = new SessionState(hiveConf) + if (clientLoader.cachedHive != null) { + Hive.set(clientLoader.cachedHive.asInstanceOf[Hive]) + } + SessionState.start(state) + state.out = new PrintStream(outputBuffer, true, "UTF-8") + state.err = new PrintStream(outputBuffer, true, "UTF-8") + state + } /** Returns the configuration for the current session. 
*/ def conf: HiveConf = state.getConf @@ -268,6 +253,9 @@ private[hive] class HiveClientImpl( } } + /** Return the associated Hive [[SessionState]] of this [[HiveClientImpl]] */ + override def getState: SessionState = withHiveState(state) + /** * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. */ @@ -374,10 +362,30 @@ private[hive] class HiveClientImpl( Option(client.getTable(dbName, tableName, false)).map { h => // Note: Hive separates partition columns and the schema, but for us the // partition columns are part of the schema + val cols = h.getCols.asScala.map(fromHiveColumn) val partCols = h.getPartCols.asScala.map(fromHiveColumn) - val schema = StructType(h.getCols.asScala.map(fromHiveColumn) ++ partCols) + val schema = StructType(cols ++ partCols) + + val bucketSpec = if (h.getNumBuckets > 0) { + val sortColumnOrders = h.getSortCols.asScala + // Currently Spark only supports columns to be sorted in ascending order + // but Hive can support both ascending and descending order. If all the columns + // are sorted in ascending order, only then propagate the sortedness information + // to downstream processing / optimizations in Spark + // TODO: In future we can have Spark support columns sorted in descending order + val allAscendingSorted = sortColumnOrders.forall(_.getOrder == HIVE_COLUMN_ORDER_ASC) + + val sortColumnNames = if (allAscendingSorted) { + sortColumnOrders.map(_.getCol) + } else { + Seq.empty + } + Option(BucketSpec(h.getNumBuckets, h.getBucketCols.asScala, sortColumnNames)) + } else { + None + } - // Skew spec, storage handler, and bucketing info can't be mapped to CatalogTable (yet) + // Skew spec and storage handler can't be mapped to CatalogTable (yet) val unsupportedFeatures = ArrayBuffer.empty[String] if (!h.getSkewedColNames.isEmpty) { @@ -388,16 +396,55 @@ private[hive] class HiveClientImpl( unsupportedFeatures += "storage handler" } - if (!h.getBucketCols.isEmpty) { - unsupportedFeatures += "bucketing" - } - if (h.getTableType == HiveTableType.VIRTUAL_VIEW && partCols.nonEmpty) { unsupportedFeatures += "partitioned view" } val properties = Option(h.getParameters).map(_.asScala.toMap).orNull + // Hive-generated Statistics are also recorded in ignoredProperties + val ignoredProperties = scala.collection.mutable.Map.empty[String, String] + for (key <- HiveStatisticsProperties; value <- properties.get(key)) { + ignoredProperties += key -> value + } + + val excludedTableProperties = HiveStatisticsProperties ++ Set( + // The property value of "comment" is moved to the dedicated field "comment" + "comment", + // For EXTERNAL_TABLE, the table properties has a particular field "EXTERNAL". This is added + // in the function toHiveTable. + "EXTERNAL" + ) + + val filteredProperties = properties.filterNot { + case (key, _) => excludedTableProperties.contains(key) + } + val comment = properties.get("comment") + + // Here we are reading statistics from Hive. + // Note that this statistics could be overridden by Spark's statistics if that's available. + val totalSize = properties.get(StatsSetupConst.TOTAL_SIZE).map(BigInt(_)) + val rawDataSize = properties.get(StatsSetupConst.RAW_DATA_SIZE).map(BigInt(_)) + val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_)).filter(_ >= 0) + // TODO: check if this estimate is valid for tables after partition pruning. 
+ // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be + // relatively cheap if parameters for the table are populated into the metastore. + // Currently, only totalSize, rawDataSize, and rowCount are used to build the field `stats` + // TODO: stats should include all the other two fields (`numFiles` and `numPartitions`). + // (see StatsSetupConst in Hive) + val stats = + // When table is external, `totalSize` is always zero, which will influence join strategy + // so when `totalSize` is zero, use `rawDataSize` instead. When `rawDataSize` is also zero, + // return None. Later, we will use the other ways to estimate the statistics. + if (totalSize.isDefined && totalSize.get > 0L) { + Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount)) + } else if (rawDataSize.isDefined && rawDataSize.get > 0) { + Some(CatalogStatistics(sizeInBytes = rawDataSize.get, rowCount = rowCount)) + } else { + // TODO: still fill the rowCount even if sizeInBytes is empty. Might break anything? + None + } + CatalogTable( identifier = TableIdentifier(h.getTableName, Option(h.getDbName)), tableType = h.getTableType match { @@ -409,9 +456,11 @@ private[hive] class HiveClientImpl( }, schema = schema, partitionColumnNames = partCols.map(_.name), - // We can not populate bucketing information for Hive tables as Spark SQL has a different - // implementation of hash function from Hive. - bucketSpec = None, + // If the table is written by Spark, we will put bucketing information in table properties, + // and will always overwrite the bucket spec in hive metastore by the bucketing information + // in table properties. This means, if we have bucket spec in both hive metastore and + // table properties, we will trust the one in table properties. + bucketSpec = bucketSpec, owner = h.getOwner, createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, @@ -433,13 +482,15 @@ private[hive] class HiveClientImpl( ), // For EXTERNAL_TABLE, the table properties has a particular field "EXTERNAL". This is added // in the function toHiveTable. - properties = properties.filter(kv => kv._1 != "comment" && kv._1 != "EXTERNAL"), - comment = properties.get("comment"), + properties = filteredProperties, + stats = stats, + comment = comment, // In older versions of Spark(before 2.2.0), we expand the view original text and store // that into `viewExpandedText`, and that should be used in view resolution. So we get // `viewExpandedText` instead of `viewOriginalText` for viewText here. viewText = Option(h.getViewExpandedText), - unsupportedFeatures = unsupportedFeatures) + unsupportedFeatures = unsupportedFeatures, + ignoredProperties = ignoredProperties.toMap) } } @@ -455,10 +506,18 @@ private[hive] class HiveClientImpl( shim.dropTable(client, dbName, tableName, true, ignoreIfNotExists, purge) } - override def alterTable(tableName: String, table: CatalogTable): Unit = withHiveState { - val hiveTable = toHiveTable(table, Some(userName)) + override def alterTable( + dbName: String, + tableName: String, + table: CatalogTable): Unit = withHiveState { + // getTableOption removes all the Hive-specific properties. Here, we fill them back to ensure + // these properties are still available to the others that share the same Hive metastore. + // If users explicitly alter these Hive-specific properties through ALTER TABLE DDL, we respect + // these user-specified values. 
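The `table.copy` just below merges the preserved Hive-generated properties back in under the table's own properties. Because `++` on an immutable `Map` keeps the right-hand value on key collisions, a property the user set explicitly through ALTER TABLE wins over the preserved statistic, while untouched statistics are written back unchanged. A small illustration with made-up values:

```scala
val ignoredProperties = Map("totalSize" -> "1024", "numFiles" -> "3")      // recorded when the table was read
val userProperties    = Map("owner.team" -> "data-eng", "totalSize" -> "2048") // set via ALTER TABLE ... SET TBLPROPERTIES

val merged = ignoredProperties ++ userProperties
// Map(totalSize -> 2048, numFiles -> 3, owner.team -> data-eng)
// "totalSize" follows the user-specified value; "numFiles" survives the round trip unchanged.
```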
+ val hiveTable = toHiveTable( + table.copy(properties = table.ignoredProperties ++ table.properties), Some(userName)) // Do not use `table.qualifiedName` here because this may be a rename - val qualifiedTableName = s"${table.database}.$tableName" + val qualifiedTableName = s"$dbName.$tableName" shim.alterTable(client, qualifiedTableName, hiveTable) } @@ -801,7 +860,12 @@ private[hive] object HiveClientImpl { throw new SparkException("Cannot recognize hive type string: " + hc.getType, e) } - val metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build() + val metadata = if (hc.getType != columnType.catalogString) { + new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build() + } else { + Metadata.empty + } + val field = StructField( name = hc.getName, dataType = columnType, @@ -839,7 +903,7 @@ private[hive] object HiveClientImpl { } // after SPARK-19279, it is not allowed to create a hive table with an empty schema, // so here we should not add a default col schema - if (schema.isEmpty && DDLUtils.isDatasourceTable(table)) { + if (schema.isEmpty && HiveExternalCatalog.isDatasourceTable(table)) { // This is a hack to preserve existing behavior. Before Spark 2.0, we do not // set a default serde here (this was done in Hive), and so if the user provides // an empty schema Hive would automatically populate the schema with a single @@ -871,6 +935,23 @@ private[hive] object HiveClientImpl { hiveTable.setViewOriginalText(t) hiveTable.setViewExpandedText(t) } + + table.bucketSpec match { + case Some(bucketSpec) if DDLUtils.isHiveTable(table) => + hiveTable.setNumBuckets(bucketSpec.numBuckets) + hiveTable.setBucketCols(bucketSpec.bucketColumnNames.toList.asJava) + + if (bucketSpec.sortColumnNames.nonEmpty) { + hiveTable.setSortCols( + bucketSpec.sortColumnNames + .map(col => new Order(col, HIVE_COLUMN_ORDER_ASC)) + .toList + .asJava + ) + } + case _ => + } + hiveTable } @@ -900,6 +981,7 @@ private[hive] object HiveClientImpl { tpart.setTableName(ht.getTableName) tpart.setValues(partValues.asJava) tpart.setSd(storageDesc) + tpart.setParameters(mutable.Map(p.parameters.toSeq: _*).asJava) new HivePartition(ht, tpart) } @@ -921,4 +1003,14 @@ private[hive] object HiveClientImpl { parameters = if (hp.getParameters() != null) hp.getParameters().asScala.toMap else Map.empty) } + + // Below is the key of table properties for storing Hive-generated statistics + private val HiveStatisticsProperties = Set( + StatsSetupConst.COLUMN_STATS_ACCURATE, + StatsSetupConst.NUM_FILES, + StatsSetupConst.NUM_PARTITIONS, + StatsSetupConst.ROW_COUNT, + StatsSetupConst.RAW_DATA_SIZE, + StatsSetupConst.TOTAL_SIZE + ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 7abb9f06b131..cde20da186ac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ +import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -46,6 +47,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTableParti import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegralType, StringType} +import 
org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -589,18 +591,67 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet - filters.collect { - case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => - s"${a.name} ${op.symbol} $v" - case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => - s"$v ${op.symbol} ${a.name}" - case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) + object ExtractableLiteral { + def unapply(expr: Expression): Option[String] = expr match { + case Literal(value, _: IntegralType) => Some(value.toString) + case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString)) + case _ => None + } + } + + object ExtractableLiterals { + def unapply(exprs: Seq[Expression]): Option[Seq[String]] = { + exprs.map(ExtractableLiteral.unapply).foldLeft(Option(Seq.empty[String])) { + case (Some(accum), Some(value)) => Some(accum :+ value) + case _ => None + } + } + } + + object ExtractableValues { + private lazy val valueToLiteralString: PartialFunction[Any, String] = { + case value: Byte => value.toString + case value: Short => value.toString + case value: Int => value.toString + case value: Long => value.toString + case value: UTF8String => quoteStringLiteral(value.toString) + } + + def unapply(values: Set[Any]): Option[Seq[String]] = { + values.toSeq.foldLeft(Option(Seq.empty[String])) { + case (Some(accum), value) if valueToLiteralString.isDefinedAt(value) => + Some(accum :+ valueToLiteralString(value)) + case _ => None + } + } + } + + def convertInToOr(a: Attribute, values: Seq[String]): String = { + values.map(value => s"${a.name} = $value").mkString("(", " or ", ")") + } + + lazy val convert: PartialFunction[Expression, String] = { + case In(a: Attribute, ExtractableLiterals(values)) + if !varcharKeys.contains(a.name) && values.nonEmpty => + convertInToOr(a, values) + case InSet(a: Attribute, ExtractableValues(values)) + if !varcharKeys.contains(a.name) && values.nonEmpty => + convertInToOr(a, values) + case op @ BinaryComparison(a: Attribute, ExtractableLiteral(value)) if !varcharKeys.contains(a.name) => - s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" - case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) + s"${a.name} ${op.symbol} $value" + case op @ BinaryComparison(ExtractableLiteral(value), a: Attribute) if !varcharKeys.contains(a.name) => - s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" - }.mkString(" and ") + s"$value ${op.symbol} ${a.name}" + case op @ And(expr1, expr2) + if convert.isDefinedAt(expr1) || convert.isDefinedAt(expr2) => + (convert.lift(expr1) ++ convert.lift(expr2)).mkString("(", " and ", ")") + case op @ Or(expr1, expr2) + if convert.isDefinedAt(expr1) && convert.isDefinedAt(expr2) => + s"(${convert(expr1)} or ${convert(expr2)})" + } + + filters.map(convert.lift).collect { case Some(filterString) => filterString }.mkString(" and ") } private def quoteStringLiteral(str: String): String = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index e95f9ea48043..930f0dd4b32b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -22,7 +22,6 @@ import 
java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} import java.util -import scala.language.reflectiveCalls import scala.util.Try import org.apache.commons.io.{FileUtils, IOUtils} @@ -93,7 +92,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "14" | "0.14" | "0.14.0" => hive.v14 case "1.0" | "1.0.0" => hive.v1_0 case "1.1" | "1.1.0" => hive.v1_1 - case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 + case "1.2" | "1.2.0" | "1.2.1" | "1.2.2" => hive.v1_2 case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index f9635e36549e..c14154a3b3c2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -56,7 +56,7 @@ package object client { "net.hydromatic:linq4j", "net.hydromatic:quidem")) - case object v1_2 extends HiveVersion("1.2.1", + case object v1_2 extends HiveVersion("1.2.2", exclusions = Seq("eigenbase:eigenbase-properties", "org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 41c6b18e9d79..65e8b4e3c725 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -62,7 +62,7 @@ case class CreateHiveTableAsSelectCommand( Map(), query, overwrite = false, - ifNotExists = false)).toRdd + ifPartitionNotExists = false)).toRdd } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data @@ -78,7 +78,7 @@ case class CreateHiveTableAsSelectCommand( Map(), query, overwrite = true, - ifNotExists = false)).toRdd + ifPartitionNotExists = false)).toRdd } catch { case NonFatal(e) => // drop the created table. 
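To make the partition-filter pushdown added to `Shim_v0_13` above concrete: IN and InSet predicates on partition columns are rewritten into chains of equality ORs (the metastore filter language has no IN), conjunctions are pushed even if only one side converts, and disjunctions are pushed only when both sides convert. A hedged sketch using plain strings instead of Catalyst expressions; the column names and values are invented:

```scala
// IN on an integral partition column becomes a chain of equality ORs.
def convertInToOr(column: String, values: Seq[String]): String =
  values.map(v => s"$column = $v").mkString("(", " or ", ")")

convertInToOr("part", Seq("1", "2", "3"))
// "(part = 1 or part = 2 or part = 3)"

// Roughly, for combined predicates:
//   part IN (1, 2) AND someUdf(x)  ->  only the convertible conjunct is sent: "((part = 1 or part = 2))"
//   part = 1 OR someUdf(x)         ->  nothing is sent; the predicate is evaluated in Spark instead
```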
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 666548d1a490..48d0b4a63e54 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -30,13 +30,15 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.analysis.CastSupport +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClientImpl +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, DataType} import org.apache.spark.util.Utils @@ -50,14 +52,16 @@ import org.apache.spark.util.Utils private[hive] case class HiveTableScanExec( requestedAttributes: Seq[Attribute], - relation: CatalogRelation, + relation: HiveTableRelation, partitionPruningPred: Seq[Expression])( @transient private val sparkSession: SparkSession) - extends LeafExecNode { + extends LeafExecNode with CastSupport { require(partitionPruningPred.isEmpty || relation.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") + override def conf: SQLConf = sparkSession.sessionState.conf + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -104,7 +108,7 @@ case class HiveTableScanExec( hadoopConf) private def castFromString(value: String, dataType: DataType) = { - Cast(Literal(value), dataType).eval(null) + cast(Literal(value), dataType).eval(null) } private def addColumnMetadataToConf(hiveConf: Configuration): Unit = { @@ -205,8 +209,8 @@ case class HiveTableScanExec( val input: AttributeSeq = relation.output HiveTableScanExec( requestedAttributes.map(QueryPlan.normalizeExprId(_, input)), - relation.canonicalized.asInstanceOf[CatalogRelation], - partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession) + relation.canonicalized, + QueryPlan.normalizePredicates(partitionPruningPred, input))(sparkSession) } override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala new file mode 100644 index 000000000000..918c8be00d69 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.language.existentials + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.common.FileUtils +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.mapred._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.hive.client.HiveClientImpl + +/** + * Command for writing the results of `query` to file system. + * + * The syntax of using this command in SQL is: + * {{{ + * INSERT OVERWRITE [LOCAL] DIRECTORY + * path + * [ROW FORMAT row_format] + * [STORED AS file_format] + * SELECT ... + * }}} + * + * @param isLocal whether the path specified in `storage` is a local directory + * @param storage storage format used to describe how the query result is stored. + * @param query the logical plan representing data to write to + * @param overwrite whether overwrites existing directory + */ +case class InsertIntoHiveDirCommand( + isLocal: Boolean, + storage: CatalogStorageFormat, + query: LogicalPlan, + overwrite: Boolean) extends SaveAsHiveFile { + + override def children: Seq[LogicalPlan] = query :: Nil + + override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { + assert(children.length == 1) + assert(storage.locationUri.nonEmpty) + + val hiveTable = HiveClientImpl.toHiveTable(CatalogTable( + identifier = TableIdentifier(storage.locationUri.get.toString, Some("default")), + tableType = org.apache.spark.sql.catalyst.catalog.CatalogTableType.VIEW, + storage = storage, + schema = query.schema + )) + hiveTable.getMetadata.put(serdeConstants.SERIALIZATION_LIB, + storage.serde.getOrElse(classOf[LazySimpleSerDe].getName)) + + val tableDesc = new TableDesc( + hiveTable.getInputFormatClass, + hiveTable.getOutputFormatClass, + hiveTable.getMetadata + ) + + val hadoopConf = sparkSession.sessionState.newHadoopConf() + val jobConf = new JobConf(hadoopConf) + + val targetPath = new Path(storage.locationUri.get) + val writeToPath = + if (isLocal) { + val localFileSystem = FileSystem.getLocal(jobConf) + localFileSystem.makeQualified(targetPath) + } else { + val qualifiedPath = FileUtils.makeQualified(targetPath, hadoopConf) + val dfs = qualifiedPath.getFileSystem(jobConf) + if (!dfs.exists(qualifiedPath)) { + dfs.mkdirs(qualifiedPath.getParent) + } + qualifiedPath + } + + val tmpPath = getExternalTmpPath(sparkSession, hadoopConf, writeToPath) + val fileSinkConf = new org.apache.spark.sql.hive.HiveShim.ShimFileSinkDesc( + tmpPath.toString, tableDesc, false) + + try { + saveAsHiveFile( + sparkSession = sparkSession, + plan = children.head, + hadoopConf = hadoopConf, + fileSinkConf = fileSinkConf, + outputLocation = tmpPath.toString) + + val 
fs = writeToPath.getFileSystem(hadoopConf) + if (overwrite && fs.exists(writeToPath)) { + fs.listStatus(writeToPath).foreach { existFile => + if (Option(existFile.getPath) != createdTempDir) fs.delete(existFile.getPath, true) + } + } + + fs.listStatus(tmpPath).foreach { + tmpFile => fs.rename(tmpFile.getPath, writeToPath) + } + } catch { + case e: Throwable => + throw new SparkException( + "Failed inserting overwrite directory " + storage.locationUri.get, e) + } finally { + deleteExternalTmpPath(hadoopConf) + } + + Seq.empty[Row] + } +} + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3682dc850790..e5b59ed7a1a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,31 +17,22 @@ package org.apache.spark.sql.hive.execution -import java.io.IOException -import java.net.URI -import java.text.SimpleDateFormat -import java.util.{Date, Locale, Random} - import scala.util.control.NonFatal -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.hive.common.FileUtils -import org.apache.hadoop.hive.ql.exec.TaskRunner +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.spark.internal.io.FileCommitProtocol -import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.RunnableCommand -import org.apache.spark.sql.execution.datasources.FileFormatWriter +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.client.{HiveClientImpl, HiveVersion} -import org.apache.spark.SparkException +import org.apache.spark.sql.hive.client.HiveClientImpl /** @@ -71,159 +62,28 @@ import org.apache.spark.SparkException * }}}. * @param query the logical plan representing data to write to. * @param overwrite overwrite existing table or partitions. - * @param ifNotExists If true, only write if the table or partition does not exist. + * @param ifPartitionNotExists If true, only write if the partition does not exist. + * Only valid for static partitions. 
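Per the renamed parameter above, `IF NOT EXISTS` only applies to a fully static partition spec: the overwrite is skipped when that exact partition already exists. A hedged usage sketch, assuming an existing SparkSession `spark`; the table and columns are invented:

```scala
// Writes ds='2017-07-01' only if that partition does not exist yet.
// With a dynamic spec such as PARTITION (ds), the IF NOT EXISTS clause is not allowed.
spark.sql("""
  INSERT OVERWRITE TABLE logs PARTITION (ds = '2017-07-01') IF NOT EXISTS
  SELECT id, msg FROM staging_logs WHERE ds = '2017-07-01'
""")
```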
*/ case class InsertIntoHiveTable( table: CatalogTable, partition: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) extends RunnableCommand { - - override protected def innerChildren: Seq[LogicalPlan] = query :: Nil - - var createdTempDir: Option[Path] = None - - private def executionId: String = { - val rand: Random = new Random - val format = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss_SSS", Locale.US) - "hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong) - } - - private def getStagingDir( - inputPath: Path, - hadoopConf: Configuration, - stagingDir: String): Path = { - val inputPathUri: URI = inputPath.toUri - val inputPathName: String = inputPathUri.getPath - val fs: FileSystem = inputPath.getFileSystem(hadoopConf) - val stagingPathName: String = - if (inputPathName.indexOf(stagingDir) == -1) { - new Path(inputPathName, stagingDir).toString - } else { - inputPathName.substring(0, inputPathName.indexOf(stagingDir) + stagingDir.length) - } - val dir: Path = - fs.makeQualified( - new Path(stagingPathName + "_" + executionId + "-" + TaskRunner.getTaskRunnerID)) - logDebug("Created staging dir = " + dir + " for path = " + inputPath) - try { - if (!FileUtils.mkdir(fs, dir, true, hadoopConf)) { - throw new IllegalStateException("Cannot create staging directory '" + dir.toString + "'") - } - createdTempDir = Some(dir) - fs.deleteOnExit(dir) - } catch { - case e: IOException => - throw new RuntimeException( - "Cannot create staging directory '" + dir.toString + "': " + e.getMessage, e) - } - dir - } - - private def getExternalScratchDir( - extURI: URI, - hadoopConf: Configuration, - stagingDir: String): Path = { - getStagingDir( - new Path(extURI.getScheme, extURI.getAuthority, extURI.getPath), - hadoopConf, - stagingDir) - } - - def getExternalTmpPath( - path: Path, - hiveVersion: HiveVersion, - hadoopConf: Configuration, - stagingDir: String, - scratchDir: String): Path = { - import org.apache.spark.sql.hive.client.hive._ - - // Before Hive 1.1, when inserting into a table, Hive will create the staging directory under - // a common scratch directory. After the writing is finished, Hive will simply empty the table - // directory and move the staging directory to it. - // After Hive 1.1, Hive will create the staging directory under the table directory, and when - // moving staging directory to table directory, Hive will still empty the table directory, but - // will exclude the staging directory there. - // We have to follow the Hive behavior here, to avoid troubles. For example, if we create - // staging directory under the table director for Hive prior to 1.1, the staging directory will - // be removed by Hive when Hive is trying to empty the table directory. - val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) - val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1) - - // Ensure all the supported versions are considered here. 
- assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == - allSupportedHiveVersions) - - if (hiveVersionsUsingOldExternalTempPath.contains(hiveVersion)) { - oldVersionExternalTempPath(path, hadoopConf, scratchDir) - } else if (hiveVersionsUsingNewExternalTempPath.contains(hiveVersion)) { - newVersionExternalTempPath(path, hadoopConf, stagingDir) - } else { - throw new IllegalStateException("Unsupported hive version: " + hiveVersion.fullVersion) - } - } - - // Mostly copied from Context.java#getExternalTmpPath of Hive 0.13 - def oldVersionExternalTempPath( - path: Path, - hadoopConf: Configuration, - scratchDir: String): Path = { - val extURI: URI = path.toUri - val scratchPath = new Path(scratchDir, executionId) - var dirPath = new Path( - extURI.getScheme, - extURI.getAuthority, - scratchPath.toUri.getPath + "-" + TaskRunner.getTaskRunnerID()) - - try { - val fs: FileSystem = dirPath.getFileSystem(hadoopConf) - dirPath = new Path(fs.makeQualified(dirPath).toString()) - - if (!FileUtils.mkdir(fs, dirPath, true, hadoopConf)) { - throw new IllegalStateException("Cannot create staging directory: " + dirPath.toString) - } - createdTempDir = Some(dirPath) - fs.deleteOnExit(dirPath) - } catch { - case e: IOException => - throw new RuntimeException("Cannot create staging directory: " + dirPath.toString, e) - } - dirPath - } - - // Mostly copied from Context.java#getExternalTmpPath of Hive 1.2 - def newVersionExternalTempPath( - path: Path, - hadoopConf: Configuration, - stagingDir: String): Path = { - val extURI: URI = path.toUri - if (extURI.getScheme == "viewfs") { - getExtTmpPathRelTo(path.getParent, hadoopConf, stagingDir) - } else { - new Path(getExternalScratchDir(extURI, hadoopConf, stagingDir), "-ext-10000") - } - } + ifPartitionNotExists: Boolean) extends SaveAsHiveFile { - def getExtTmpPathRelTo( - path: Path, - hadoopConf: Configuration, - stagingDir: String): Path = { - new Path(getStagingDir(path, hadoopConf, stagingDir), "-ext-10000") // Hive uses 10000 - } + override def children: Seq[LogicalPlan] = query :: Nil /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. 
*/ - override def run(sparkSession: SparkSession): Seq[Row] = { - val sessionState = sparkSession.sessionState + override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { + assert(children.length == 1) + val externalCatalog = sparkSession.sharedState.externalCatalog - val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version - val hadoopConf = sessionState.newHadoopConf() - val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") - val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") + val hadoopConf = sparkSession.sessionState.newHadoopConf() val hiveQlTable = HiveClientImpl.toHiveTable(table) // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer @@ -238,23 +98,8 @@ case class InsertIntoHiveTable( hiveQlTable.getMetadata ) val tableLocation = hiveQlTable.getDataLocation - val tmpLocation = - getExternalTmpPath(tableLocation, hiveVersion, hadoopConf, stagingDir, scratchDir) + val tmpLocation = getExternalTmpPath(sparkSession, hadoopConf, tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean - - if (isCompressed) { - // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", - // "mapreduce.output.fileoutputformat.compress.codec", and - // "mapreduce.output.fileoutputformat.compress.type" - // have no impact on ORC because it uses table properties to store compression information. - hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") - fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(hadoopConf - .get("mapreduce.output.fileoutputformat.compress.codec")) - fileSinkConf.setCompressType(hadoopConf - .get("mapreduce.output.fileoutputformat.compress.type")) - } val numDynamicPartitions = partition.values.count(_.isEmpty) val numStaticPartitions = partition.values.count(_.nonEmpty) @@ -295,11 +140,26 @@ case class InsertIntoHiveTable( } } - val committer = FileCommitProtocol.instantiate( - sparkSession.sessionState.conf.fileCommitProtocolClass, - jobId = java.util.UUID.randomUUID().toString, - outputPath = tmpLocation.toString, - isAppend = false) + table.bucketSpec match { + case Some(bucketSpec) => + // Writes to bucketed hive tables are allowed only if user does not care about maintaining + // table's bucketing ie. both "hive.enforce.bucketing" and "hive.enforce.sorting" are + // set to false + val enforceBucketingConfig = "hive.enforce.bucketing" + val enforceSortingConfig = "hive.enforce.sorting" + + val message = s"Output Hive table ${table.identifier} is bucketed but Spark" + + "currently does NOT populate bucketed output which is compatible with Hive." 
+ + if (hadoopConf.get(enforceBucketingConfig, "true").toBoolean || + hadoopConf.get(enforceSortingConfig, "true").toBoolean) { + throw new AnalysisException(message) + } else { + logWarning(message + s" Inserting data anyways since both $enforceBucketingConfig and " + + s"$enforceSortingConfig are set to false.") + } + case _ => // do nothing since table has no bucketing + } val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name => query.resolve(name :: Nil, sparkSession.sessionState.analyzer.resolver).getOrElse { @@ -308,17 +168,13 @@ case class InsertIntoHiveTable( }.asInstanceOf[Attribute] } - FileFormatWriter.write( + saveAsHiveFile( sparkSession = sparkSession, - queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, - fileFormat = new HiveFileFormat(fileSinkConf), - committer = committer, - outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty), + plan = children.head, hadoopConf = hadoopConf, - partitionColumns = partitionAttributes, - bucketSpec = None, - refreshFunction = _ => (), - options = Map.empty) + fileSinkConf = fileSinkConf, + outputLocation = tmpLocation.toString, + partitionAttributes = partitionAttributes) if (partition.nonEmpty) { if (numDynamicPartitions > 0) { @@ -342,7 +198,7 @@ case class InsertIntoHiveTable( var doHiveOverwrite = overwrite - if (oldPart.isEmpty || !ifNotExists) { + if (oldPart.isEmpty || !ifPartitionNotExists) { // SPARK-18107: Insert overwrite runs much slower than hive-client. // Newer Hive largely improves insert overwrite performance. As Spark uses older Hive // version and we may not want to catch up new Hive version every time. We delete the @@ -386,17 +242,14 @@ case class InsertIntoHiveTable( // Attempt to delete the staging directory and the inclusive files. If failed, the files are // expected to be dropped at the normal termination of VM since deleteOnExit is used. - try { - createdTempDir.foreach { path => path.getFileSystem(hadoopConf).delete(path, true) } - } catch { - case NonFatal(e) => - logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e) - } + deleteExternalTmpPath(hadoopConf) // un-cache this table. sparkSession.catalog.uncacheTable(table.identifier.quotedString) sparkSession.sessionState.catalog.refreshTable(table.identifier) + CommandUtils.updateTableStats(sparkSession, table) + // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala new file mode 100644 index 000000000000..2d74ef040ef5 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.io.{File, IOException} +import java.net.URI +import java.text.SimpleDateFormat +import java.util.{Date, Locale, Random} + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.common.FileUtils +import org.apache.hadoop.hive.ql.exec.TaskRunner + +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand +import org.apache.spark.sql.execution.datasources.FileFormatWriter +import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} +import org.apache.spark.sql.hive.client.HiveVersion + +// Base trait from which all hive insert statement physical execution extends. +private[hive] trait SaveAsHiveFile extends DataWritingCommand { + + var createdTempDir: Option[Path] = None + + protected def saveAsHiveFile( + sparkSession: SparkSession, + plan: SparkPlan, + hadoopConf: Configuration, + fileSinkConf: FileSinkDesc, + outputLocation: String, + customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, + partitionAttributes: Seq[Attribute] = Nil): Set[String] = { + + val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean + if (isCompressed) { + // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", + // "mapreduce.output.fileoutputformat.compress.codec", and + // "mapreduce.output.fileoutputformat.compress.type" + // have no impact on ORC because it uses table properties to store compression information. 
+ hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") + fileSinkConf.setCompressed(true) + fileSinkConf.setCompressCodec(hadoopConf + .get("mapreduce.output.fileoutputformat.compress.codec")) + fileSinkConf.setCompressType(hadoopConf + .get("mapreduce.output.fileoutputformat.compress.type")) + } + + val committer = FileCommitProtocol.instantiate( + sparkSession.sessionState.conf.fileCommitProtocolClass, + jobId = java.util.UUID.randomUUID().toString, + outputPath = outputLocation) + + FileFormatWriter.write( + sparkSession = sparkSession, + plan = plan, + fileFormat = new HiveFileFormat(fileSinkConf), + committer = committer, + outputSpec = FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations), + hadoopConf = hadoopConf, + partitionColumns = partitionAttributes, + bucketSpec = None, + statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), + options = Map.empty) + } + + protected def getExternalTmpPath( + sparkSession: SparkSession, + hadoopConf: Configuration, + path: Path): Path = { + import org.apache.spark.sql.hive.client.hive._ + + // Before Hive 1.1, when inserting into a table, Hive will create the staging directory under + // a common scratch directory. After the writing is finished, Hive will simply empty the table + // directory and move the staging directory to it. + // After Hive 1.1, Hive will create the staging directory under the table directory, and when + // moving staging directory to table directory, Hive will still empty the table directory, but + // will exclude the staging directory there. + // We have to follow the Hive behavior here, to avoid troubles. For example, if we create + // staging directory under the table director for Hive prior to 1.1, the staging directory will + // be removed by Hive when Hive is trying to empty the table directory. + val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1) + + // Ensure all the supported versions are considered here. + assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == + allSupportedHiveVersions) + + val externalCatalog = sparkSession.sharedState.externalCatalog + val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version + val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") + val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") + + if (hiveVersionsUsingOldExternalTempPath.contains(hiveVersion)) { + oldVersionExternalTempPath(path, hadoopConf, scratchDir) + } else if (hiveVersionsUsingNewExternalTempPath.contains(hiveVersion)) { + newVersionExternalTempPath(path, hadoopConf, stagingDir) + } else { + throw new IllegalStateException("Unsupported hive version: " + hiveVersion.fullVersion) + } + } + + protected def deleteExternalTmpPath(hadoopConf: Configuration) : Unit = { + // Attempt to delete the staging directory and the inclusive files. If failed, the files are + // expected to be dropped at the normal termination of VM since deleteOnExit is used. + try { + createdTempDir.foreach { path => + val fs = path.getFileSystem(hadoopConf) + if (fs.delete(path, true)) { + // If we successfully delete the staging directory, remove it from FileSystem's cache. 
+ fs.cancelDeleteOnExit(path) + } + } + } catch { + case NonFatal(e) => + val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") + logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e) + } + } + + // Mostly copied from Context.java#getExternalTmpPath of Hive 0.13 + private def oldVersionExternalTempPath( + path: Path, + hadoopConf: Configuration, + scratchDir: String): Path = { + val extURI: URI = path.toUri + val scratchPath = new Path(scratchDir, executionId) + var dirPath = new Path( + extURI.getScheme, + extURI.getAuthority, + scratchPath.toUri.getPath + "-" + TaskRunner.getTaskRunnerID()) + + try { + val fs: FileSystem = dirPath.getFileSystem(hadoopConf) + dirPath = new Path(fs.makeQualified(dirPath).toString()) + + if (!FileUtils.mkdir(fs, dirPath, true, hadoopConf)) { + throw new IllegalStateException("Cannot create staging directory: " + dirPath.toString) + } + createdTempDir = Some(dirPath) + fs.deleteOnExit(dirPath) + } catch { + case e: IOException => + throw new RuntimeException("Cannot create staging directory: " + dirPath.toString, e) + } + dirPath + } + + // Mostly copied from Context.java#getExternalTmpPath of Hive 1.2 + private def newVersionExternalTempPath( + path: Path, + hadoopConf: Configuration, + stagingDir: String): Path = { + val extURI: URI = path.toUri + if (extURI.getScheme == "viewfs") { + getExtTmpPathRelTo(path.getParent, hadoopConf, stagingDir) + } else { + new Path(getExternalScratchDir(extURI, hadoopConf, stagingDir), "-ext-10000") + } + } + + private def getExtTmpPathRelTo( + path: Path, + hadoopConf: Configuration, + stagingDir: String): Path = { + new Path(getStagingDir(path, hadoopConf, stagingDir), "-ext-10000") // Hive uses 10000 + } + + private def getExternalScratchDir( + extURI: URI, + hadoopConf: Configuration, + stagingDir: String): Path = { + getStagingDir( + new Path(extURI.getScheme, extURI.getAuthority, extURI.getPath), + hadoopConf, + stagingDir) + } + + private def getStagingDir( + inputPath: Path, + hadoopConf: Configuration, + stagingDir: String): Path = { + val inputPathUri: URI = inputPath.toUri + val inputPathName: String = inputPathUri.getPath + val fs: FileSystem = inputPath.getFileSystem(hadoopConf) + var stagingPathName: String = + if (inputPathName.indexOf(stagingDir) == -1) { + new Path(inputPathName, stagingDir).toString + } else { + inputPathName.substring(0, inputPathName.indexOf(stagingDir) + stagingDir.length) + } + + // SPARK-20594: This is a walk-around fix to resolve a Hive bug. Hive requires that the + // staging directory needs to avoid being deleted when users set hive.exec.stagingdir + // under the table directory. + if (FileUtils.isSubDir(new Path(stagingPathName), inputPath, fs) && + !stagingPathName.stripPrefix(inputPathName).stripPrefix(File.separator).startsWith(".")) { + logDebug(s"The staging dir '$stagingPathName' should be a child directory starts " + + "with '.' 
to avoid being deleted if we set hive.exec.stagingdir under the table " + + "directory.") + stagingPathName = new Path(inputPathName, ".hive-staging").toString + } + + val dir: Path = + fs.makeQualified( + new Path(stagingPathName + "_" + executionId + "-" + TaskRunner.getTaskRunnerID)) + logDebug("Created staging dir = " + dir + " for path = " + inputPath) + try { + if (!FileUtils.mkdir(fs, dir, true, hadoopConf)) { + throw new IllegalStateException("Cannot create staging directory '" + dir.toString + "'") + } + createdTempDir = Some(dir) + fs.deleteOnExit(dir) + } catch { + case e: IOException => + throw new RuntimeException( + "Cannot create staging directory '" + dir.toString + "': " + e.getMessage, e) + } + dir + } + + private def executionId: String = { + val rand: Random = new Random + val format = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss_SSS", Locale.US) + "hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong) + } +} + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index a83ad61b204a..e9bdcf00b934 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -42,7 +42,11 @@ import org.apache.spark.sql.types._ private[hive] case class HiveSimpleUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with CodegenFallback with Logging { + extends Expression + with HiveInspectors + with CodegenFallback + with Logging + with UserDefinedExpression { override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) @@ -119,7 +123,11 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp private[hive] case class HiveGenericUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with CodegenFallback with Logging { + extends Expression + with HiveInspectors + with CodegenFallback + with Logging + with UserDefinedExpression { override def nullable: Boolean = true @@ -191,7 +199,7 @@ private[hive] case class HiveGenericUDTF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Generator with HiveInspectors with CodegenFallback { + extends Generator with HiveInspectors with CodegenFallback with UserDefinedExpression { @transient protected lazy val function: GenericUDTF = { @@ -303,7 +311,9 @@ private[hive] case class HiveUDAFFunction( isUDAFBridgeRequired: Boolean = false, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors { + extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] + with HiveInspectors + with UserDefinedExpression { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 3a34ec55c8b0..c76f0ebb36a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, 
FileSplit} import org.apache.spark.TaskContext -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources._ @@ -58,7 +58,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { OrcFileOperator.readSchema( - files.map(_.getPath.toUri.toString), + files.map(_.getPath.toString), Some(sparkSession.sessionState.newHadoopConf()) ) } @@ -68,7 +68,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - val orcOptions = new OrcOptions(options) + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val configuration = job.getConfiguration diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index 043eb69818ba..7f94c8c57902 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -20,30 +20,34 @@ package org.apache.spark.sql.hive.orc import java.util.Locale import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.internal.SQLConf /** * Options for the ORC data source. */ -private[orc] class OrcOptions(@transient private val parameters: CaseInsensitiveMap[String]) +private[orc] class OrcOptions( + @transient private val parameters: CaseInsensitiveMap[String], + @transient private val sqlConf: SQLConf) extends Serializable { import OrcOptions._ - def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + def this(parameters: Map[String, String], sqlConf: SQLConf) = + this(CaseInsensitiveMap(parameters), sqlConf) /** - * Compression codec to use. By default snappy compression. + * Compression codec to use. * Acceptable values are defined in [[shortOrcCompressionCodecNames]]. */ val compressionCodec: String = { - // `orc.compress` is a ORC configuration. So, here we respect this as an option but - // `compression` has higher precedence than `orc.compress`. It means if both are set, - // we will use `compression`. + // `compression`, `orc.compress`, and `spark.sql.orc.compression.codec` are + // in order of precedence from highest to lowest. 
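The precedence noted above means an explicit `compression` write option beats ORC's own `orc.compress` key, which in turn beats the session-wide `spark.sql.orc.compression.codec` default (read here through `sqlConf.orcCompressionCodec`). A hedged usage sketch, assuming an existing SparkSession `spark` and DataFrame `df`; the output path is a placeholder:

```scala
// Session default: used only when neither write option is supplied.
spark.conf.set("spark.sql.orc.compression.codec", "zlib")

df.write
  .format("orc")
  .option("orc.compress", "SNAPPY") // ORC's native key: middle precedence
  .option("compression", "none")    // data source option: highest precedence, wins here
  .save("/tmp/orc_example")
```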
val orcCompressionConf = parameters.get(OrcRelation.ORC_COMPRESSION) val codecName = parameters .get("compression") .orElse(orcCompressionConf) - .getOrElse("snappy").toLowerCase(Locale.ROOT) + .getOrElse(sqlConf.orcCompressionCodec) + .toLowerCase(Locale.ROOT) if (!shortOrcCompressionCodecNames.contains(codecName)) { val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) throw new IllegalArgumentException(s"Codec [$codecName] " + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index d9bb1f8c7edc..b6be00dbb3a7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.test import java.io.File +import java.net.URI import java.util.{Set => JavaSet} import scala.collection.JavaConverters._ @@ -34,8 +35,8 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient @@ -51,11 +52,13 @@ object TestHive "TestSQLContext", new SparkConf() .set("spark.sql.test", "") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") .set("spark.sql.hive.metastore.barrierPrefixes", "org.apache.spark.sql.hive.execution.PairSerDe") .set("spark.sql.warehouse.dir", TestHiveContext.makeWarehouseDir().toURI.getPath) // SPARK-8910 - .set("spark.ui.enabled", "false"))) + .set("spark.ui.enabled", "false") + .set("spark.unsafe.exceptionOnMemoryLeak", "true"))) case class TestHiveVersion(hiveClient: HiveClient) @@ -294,23 +297,23 @@ private[hive] class TestHiveSparkSession( "CREATE TABLE src1 (key INT, value STRING)".cmd, s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () => { - sql( - "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") + "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)" + .cmd.apply() for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { - sql( - s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' - |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') - """.stripMargin) + s""" + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') + """.stripMargin.cmd.apply() } }), TestTable("srcpart1", () => { - sql( - "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") + "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)" + .cmd.apply() for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { - sql( - s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' - |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') - """.stripMargin) + s""" + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart1 PARTITION 
(ds='$ds',hr='$hr') + """.stripMargin.cmd.apply() } }), TestTable("src_thrift", () => { @@ -318,8 +321,7 @@ private[hive] class TestHiveSparkSession( import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} import org.apache.thrift.protocol.TBinaryProtocol - sql( - s""" + s""" |CREATE TABLE src_thrift(fake INT) |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}' |WITH SERDEPROPERTIES( @@ -329,13 +331,12 @@ private[hive] class TestHiveSparkSession( |STORED AS |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}' |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}' - """.stripMargin) + """.stripMargin.cmd.apply() - sql( - s""" - |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' - |INTO TABLE src_thrift - """.stripMargin) + s""" + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' + |INTO TABLE src_thrift + """.stripMargin.cmd.apply() }), TestTable("serdeins", s"""CREATE TABLE serdeins (key INT, value STRING) @@ -451,6 +452,8 @@ private[hive] class TestHiveSparkSession( private val loadedTables = new collection.mutable.HashSet[String] + def getLoadedTables: collection.mutable.HashSet[String] = loadedTables + def loadTestTable(name: String) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infinite mutually recursive table loading. @@ -458,7 +461,17 @@ private[hive] class TestHiveSparkSession( logDebug(s"Loading test table $name") val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) - createCmds.foreach(_()) + + // test tables are loaded lazily, so they may be loaded in the middle a query execution which + // has already set the execution id. + if (sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) == null) { + // We don't actually have a `QueryExecution` here, use a fake one instead. + SQLExecution.withNewExecutionId(this, new QueryExecution(this, OneRowRelation())) { + createCmds.foreach(_()) + } + } else { + createCmds.foreach(_()) + } if (cacheTables) { new SQLContext(self).cacheTable(name) @@ -486,16 +499,16 @@ private[hive] class TestHiveSparkSession( } } + // Clean out the Hive warehouse between each suite + val warehouseDir = new File(new URI(sparkContext.conf.get("spark.sql.warehouse.dir")).getPath) + Utils.deleteRecursively(warehouseDir) + warehouseDir.mkdir() + sharedState.cacheManager.clearCache() loadedTables.clear() - sessionState.catalog.clearTempTables() - sessionState.catalog.tableRelationCache.invalidateAll() - + sessionState.catalog.reset() metadataHive.reset() - FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). - foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } - // HDFS root scratch dir requires the write all (733) permission. For each connecting user, // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with // ${hive.scratch.dir.permission}. 
To resolve the permission issue, the simplest way is to @@ -550,7 +563,10 @@ private[hive] class TestHiveQueryExecution( val referencedTables = describedTables ++ logical.collect { case UnresolvedRelation(tableIdent) => tableIdent.table } - val referencedTestTables = referencedTables.filter(sparkSession.testTables.contains) + val resolver = sparkSession.sessionState.conf.resolver + val referencedTestTables = sparkSession.testTables.keys.filter { testTable => + referencedTables.exists(resolver(_, testTable)) + } logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index aefc9cc77da8..636ce10da373 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.test.TestHive$; -import org.apache.spark.sql.hive.aggregate.MyDoubleSum; +import test.org.apache.spark.sql.MyDoubleSum; public class JavaDataFrameSuite { private transient SQLContext hc; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index 9bf84ab1fb7a..df7988f542b7 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -19,13 +19,17 @@ package org.apache.spark.sql.hive.test import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.SparkSession import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.client.HiveClient trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive + protected val hiveClient: HiveClient = + spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client protected override def afterAll(): Unit = { try { diff --git a/sql/hive/src/test/resources/avroDecimal/decimal.avro b/sql/hive/src/test/resources/avroDecimal/decimal.avro new file mode 100755 index 000000000000..6da423f78661 Binary files /dev/null and b/sql/hive/src/test/resources/avroDecimal/decimal.avro differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index 149ce1e19511..d9cf1f361c1d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -19,12 +19,29 @@ package org.apache.spark.sql.catalyst import java.sql.Timestamp +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{If, Literal, SpecifiedWindowFrame, TimeAdd, - TimeSub, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import 
org.apache.spark.unsafe.types.CalendarInterval -class ExpressionSQLBuilderSuite extends SQLBuilderTest { +class ExpressionSQLBuilderSuite extends QueryTest with TestHiveSingleton { + protected def checkSQL(e: Expression, expectedSQL: String): Unit = { + val actualSQL = e.sql + try { + assert(actualSQL == expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following expression: + | + |${e.prettyName} + | + |$cause + """.stripMargin) + } + } + test("literal") { checkSQL(Literal("foo"), "'foo'") checkSQL(Literal("\"foo\""), "'\"foo\"'") @@ -98,27 +115,27 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { checkSQL( WindowSpecDefinition('a.int :: Nil, Nil, frame), - s"(PARTITION BY `a` $frame)" + s"(PARTITION BY `a` ${frame.sql})" ) checkSQL( WindowSpecDefinition('a.int :: 'b.string :: Nil, Nil, frame), - s"(PARTITION BY `a`, `b` $frame)" + s"(PARTITION BY `a`, `b` ${frame.sql})" ) checkSQL( WindowSpecDefinition(Nil, 'a.int.asc :: Nil, frame), - s"(ORDER BY `a` ASC NULLS FIRST $frame)" + s"(ORDER BY `a` ASC NULLS FIRST ${frame.sql})" ) checkSQL( WindowSpecDefinition(Nil, 'a.int.asc :: 'b.string.desc :: Nil, frame), - s"(ORDER BY `a` ASC NULLS FIRST, `b` DESC NULLS LAST $frame)" + s"(ORDER BY `a` ASC NULLS FIRST, `b` DESC NULLS LAST ${frame.sql})" ) checkSQL( WindowSpecDefinition('a.int :: 'b.string :: Nil, 'c.int.asc :: 'd.string.desc :: Nil, frame), - s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST $frame)" + s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST ${frame.sql})" ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 73383ae4d411..e599d1ab1d48 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -221,7 +221,7 @@ class ObjectHashAggregateExecBenchmark extends BenchmarkBase with TestHiveSingle val sessionCatalog = sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog] val functionIdentifier = FunctionIdentifier(functionName, database = None) val func = CatalogFunction(functionIdentifier, clazz.getName, resources = Nil) - sessionCatalog.registerFunction(func, ignoreIfExists = false) + sessionCatalog.registerFunction(func, overrideIfExists = false) } private def percentile_approx( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala index 939fd71b4f1e..8a7423663f28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. 
You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.hive diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala deleted file mode 100644 index 59cc6605a124..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ /dev/null @@ -1,723 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive - -import java.net.URI -import java.util.Locale - -import org.apache.spark.sql.{AnalysisException, SaveMode} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans -import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} -import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.CreateTable -import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StructType - -class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingleton { - val parser = TestHive.sessionState.sqlParser - - private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { - parser.parsePlan(sql).collect { - case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) - }.head - } - - private def assertUnsupported(sql: String): Unit = { - val e = intercept[ParseException] { - parser.parsePlan(sql) - } - assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) - } - - private def analyzeCreateTable(sql: String): CatalogTable = { - TestHive.sessionState.analyzer.execute(parser.parsePlan(sql)).collect { - case CreateTableCommand(tableDesc, _) => tableDesc - }.head - } - - test("Test CTAS #1") { - val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |COMMENT 'This is the staging page view table' - |STORED AS RCFILE - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin - - val (desc, exists) = extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) - } - - test("Test CTAS #2") { - val s2 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |COMMENT 'This is the staging page view table' - |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' - | STORED AS - | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' - | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin - - val 
(desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) - } - - test("Test CTAS #3") { - val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" - val (desc, exists) = extractTableDesc(s3) - assert(exists == false) - assert(desc.identifier.database == None) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.MANAGED) - assert(desc.storage.locationUri == None) - assert(desc.schema.isEmpty) - assert(desc.viewText == None) // TODO will be SQLText - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) - assert(desc.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(desc.properties == Map()) - } - - test("Test CTAS #4") { - val s4 = - """CREATE TABLE page_view - |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin - intercept[AnalysisException] { - extractTableDesc(s4) - } - } - - test("Test CTAS #5") { - val s5 = """CREATE TABLE ctas2 - | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" - | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") - | STORED AS RCFile - | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") - | AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin - val (desc, exists) = extractTableDesc(s5) - assert(exists == false) - assert(desc.identifier.database == None) - assert(desc.identifier.table == "ctas2") - assert(desc.tableType == CatalogTableType.MANAGED) - assert(desc.storage.locationUri == None) - assert(desc.schema.isEmpty) - assert(desc.viewText == None) // TODO will be SQLText - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.properties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) - assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) - } - - test("CTAS statement with a PARTITIONED BY clause is not allowed") { - assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + - " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp") - } - - test("CTAS statement with 
schema") { - assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src") - assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'") - } - - test("unsupported operations") { - intercept[ParseException] { - parser.parsePlan( - """ - |CREATE TEMPORARY TABLE ctas2 - |ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" - |WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") - |STORED AS RCFile - |TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") - |AS SELECT key, value FROM src ORDER BY key, value - """.stripMargin) - } - intercept[ParseException] { - parser.parsePlan( - """ - |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) - |CLUSTERED BY(user_id) INTO 256 BUCKETS - |AS SELECT key, value FROM src ORDER BY key, value - """.stripMargin) - } - intercept[ParseException] { - parser.parsePlan( - """ - |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) - |SKEWED BY (key) ON (1,5,6) - |AS SELECT key, value FROM src ORDER BY key, value - """.stripMargin) - } - intercept[ParseException] { - parser.parsePlan( - """ - |SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.contrib.serde2.TypedBytesSerDe' - |RECORDREADER 'org.apache.hadoop.hive.contrib.util.typedbytes.TypedBytesRecordReader' - |FROM testData - """.stripMargin) - } - } - - test("Invalid interval term should throw AnalysisException") { - def assertError(sql: String, errorMessage: String): Unit = { - val e = intercept[AnalysisException] { - parser.parsePlan(sql) - } - assert(e.getMessage.contains(errorMessage)) - } - assertError("select interval '42-32' year to month", - "month 32 outside range [0, 11]") - assertError("select interval '5 49:12:15' day to second", - "hour 49 outside range [0, 23]") - assertError("select interval '.1111111111' second", - "nanosecond 1111111111 outside range") - } - - test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { - val analyzer = TestHive.sparkSession.sessionState.analyzer - val plan = analyzer.execute(parser.parsePlan( - """ - |SELECT * - |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test - |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b - """.stripMargin)) - - assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) - } - - test("transform query spec") { - val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan3 = parser.parsePlan("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - - val p = ScriptTransformation( - Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), - "func", Seq.empty, plans.table("e"), null) - - comparePlans(plan1, - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - comparePlans(plan2, - p.copy(output = Seq('c.string, 'd.string))) - comparePlans(plan3, - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) - } - - test("use backticks in output of Script Transform") { - parser.parsePlan( - """SELECT `t`.`thing1` - |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`) - |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t - """.stripMargin) - } - - test("use backticks in 
output of Generator") { - parser.parsePlan( - """ - |SELECT `gentab2`.`gencol2` - |FROM `default`.`src` - |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1` - |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2` - """.stripMargin) - } - - test("use escaped backticks in output of Generator") { - parser.parsePlan( - """ - |SELECT `gen``tab2`.`gen``col2` - |FROM `default`.`src` - |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1` - |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2` - """.stripMargin) - } - - test("create table - basic") { - val query = "CREATE TABLE my_table (id int, name string)" - val (desc, allowExisting) = extractTableDesc(query) - assert(!allowExisting) - assert(desc.identifier.database.isEmpty) - assert(desc.identifier.table == "my_table") - assert(desc.tableType == CatalogTableType.MANAGED) - assert(desc.schema == new StructType().add("id", "int").add("name", "string")) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.bucketSpec.isEmpty) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.locationUri.isEmpty) - assert(desc.storage.inputFormat == - Some("org.apache.hadoop.mapred.TextInputFormat")) - assert(desc.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(desc.storage.properties.isEmpty) - assert(desc.properties.isEmpty) - assert(desc.comment.isEmpty) - } - - test("create table - with database name") { - val query = "CREATE TABLE dbx.my_table (id int, name string)" - val (desc, _) = extractTableDesc(query) - assert(desc.identifier.database == Some("dbx")) - assert(desc.identifier.table == "my_table") - } - - test("create table - temporary") { - val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" - val e = intercept[ParseException] { parser.parsePlan(query) } - assert(e.message.contains("CREATE TEMPORARY TABLE is not supported yet")) - } - - test("create table - external") { - val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" - val (desc, _) = extractTableDesc(query) - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere"))) - } - - test("create table - if not exists") { - val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" - val (_, allowExisting) = extractTableDesc(query) - assert(allowExisting) - } - - test("create table - comment") { - val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" - val (desc, _) = extractTableDesc(query) - assert(desc.comment == Some("its hot as hell below")) - } - - test("create table - partitioned columns") { - val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" - val (desc, _) = extractTableDesc(query) - assert(desc.schema == new StructType() - .add("id", "int") - .add("name", "string") - .add("month", "int")) - assert(desc.partitionColumnNames == Seq("month")) - } - - test("create table - clustered by") { - val baseQuery = "CREATE TABLE my_table (id int, name string) CLUSTERED BY(id)" - val query1 = s"$baseQuery INTO 10 BUCKETS" - val query2 = s"$baseQuery SORTED BY(id) INTO 10 BUCKETS" - val e1 = intercept[ParseException] { parser.parsePlan(query1) } - val e2 = intercept[ParseException] { parser.parsePlan(query2) } - 
assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) - } - - test("create table - skewed by") { - val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" - val query1 = s"$baseQuery(id) ON (1, 10, 100)" - val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" - val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" - val e1 = intercept[ParseException] { parser.parsePlan(query1) } - val e2 = intercept[ParseException] { parser.parsePlan(query2) } - val e3 = intercept[ParseException] { parser.parsePlan(query3) } - assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) - assert(e3.getMessage.contains("Operation not allowed")) - } - - test("create table - row format") { - val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT" - val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'" - val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')" - val query3 = - s""" - |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y' - |COLLECTION ITEMS TERMINATED BY 'a' - |MAP KEYS TERMINATED BY 'b' - |LINES TERMINATED BY '\n' - |NULL DEFINED AS 'c' - """.stripMargin - val (desc1, _) = extractTableDesc(query1) - val (desc2, _) = extractTableDesc(query2) - val (desc3, _) = extractTableDesc(query3) - assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff")) - assert(desc1.storage.properties.isEmpty) - assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff")) - assert(desc2.storage.properties == Map("k1" -> "v1")) - assert(desc3.storage.properties == Map( - "field.delim" -> "x", - "escape.delim" -> "y", - "serialization.format" -> "x", - "line.delim" -> "\n", - "colelction.delim" -> "a", // yes, it's a typo from Hive :) - "mapkey.delim" -> "b")) - } - - test("create table - file format") { - val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS" - val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'" - val query2 = s"$baseQuery ORC" - val (desc1, _) = extractTableDesc(query1) - val (desc2, _) = extractTableDesc(query2) - assert(desc1.storage.inputFormat == Some("winput")) - assert(desc1.storage.outputFormat == Some("wowput")) - assert(desc1.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) - assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - test("create table - storage handler") { - val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" - val query1 = s"$baseQuery 'org.papachi.StorageHandler'" - val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" - val e1 = intercept[ParseException] { parser.parsePlan(query1) } - val e2 = intercept[ParseException] { parser.parsePlan(query2) } - assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) - } - - test("create table - properties") { - val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" - val (desc, _) = extractTableDesc(query) - assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) - } - - test("create table - everything!") { - val query = - """ - |CREATE 
EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string) - |COMMENT 'no comment' - |PARTITIONED BY (month int) - |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1') - |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput' - |LOCATION '/path/to/mercury' - |TBLPROPERTIES ('k1'='v1', 'k2'='v2') - """.stripMargin - val (desc, allowExisting) = extractTableDesc(query) - assert(allowExisting) - assert(desc.identifier.database == Some("dbx")) - assert(desc.identifier.table == "my_table") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.schema == new StructType() - .add("id", "int") - .add("name", "string") - .add("month", "int")) - assert(desc.partitionColumnNames == Seq("month")) - assert(desc.bucketSpec.isEmpty) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.locationUri == Some(new URI("/path/to/mercury"))) - assert(desc.storage.inputFormat == Some("winput")) - assert(desc.storage.outputFormat == Some("wowput")) - assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) - assert(desc.storage.properties == Map("k1" -> "v1")) - assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) - assert(desc.comment == Some("no comment")) - } - - test("create view -- basic") { - val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1" - val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand] - assert(!command.allowExisting) - assert(command.name.database.isEmpty) - assert(command.name.table == "view1") - assert(command.originalText == Some("SELECT * FROM tab1")) - assert(command.userSpecifiedColumns.isEmpty) - } - - test("create view - full") { - val v1 = - """ - |CREATE OR REPLACE VIEW view1 - |(col1, col3 COMMENT 'hello') - |COMMENT 'BLABLA' - |TBLPROPERTIES('prop1Key'="prop1Val") - |AS SELECT * FROM tab1 - """.stripMargin - val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand] - assert(command.name.database.isEmpty) - assert(command.name.table == "view1") - assert(command.userSpecifiedColumns == Seq("col1" -> None, "col3" -> Some("hello"))) - assert(command.originalText == Some("SELECT * FROM tab1")) - assert(command.properties == Map("prop1Key" -> "prop1Val")) - assert(command.comment == Some("BLABLA")) - } - - test("create view -- partitioned view") { - val v1 = "CREATE VIEW view1 partitioned on (ds, hr) as select * from srcpart" - intercept[ParseException] { - parser.parsePlan(v1) - } - } - - test("MSCK REPAIR table") { - val sql = "MSCK REPAIR TABLE tab1" - val parsed = parser.parsePlan(sql) - val expected = AlterTableRecoverPartitionsCommand( - TableIdentifier("tab1", None), - "MSCK REPAIR TABLE") - comparePlans(parsed, expected) - } - - test("create table like") { - val v1 = "CREATE TABLE table1 LIKE table2" - val (target, source, location, exists) = parser.parsePlan(v1).collect { - case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) - }.head - assert(exists == false) - assert(target.database.isEmpty) - assert(target.table == "table1") - assert(source.database.isEmpty) - assert(source.table == "table2") - assert(location.isEmpty) - - val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2" - val (target2, source2, location2, exists2) = parser.parsePlan(v2).collect { - case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) - }.head - assert(exists2) - assert(target2.database.isEmpty) - assert(target2.table == "table1") - assert(source2.database.isEmpty) - assert(source2.table 
== "table2") - assert(location2.isEmpty) - - val v3 = "CREATE TABLE table1 LIKE table2 LOCATION '/spark/warehouse'" - val (target3, source3, location3, exists3) = parser.parsePlan(v3).collect { - case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) - }.head - assert(!exists3) - assert(target3.database.isEmpty) - assert(target3.table == "table1") - assert(source3.database.isEmpty) - assert(source3.table == "table2") - assert(location3 == Some("/spark/warehouse")) - - val v4 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2 LOCATION '/spark/warehouse'" - val (target4, source4, location4, exists4) = parser.parsePlan(v4).collect { - case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) - }.head - assert(exists4) - assert(target4.database.isEmpty) - assert(target4.table == "table1") - assert(source4.database.isEmpty) - assert(source4.table == "table2") - assert(location4 == Some("/spark/warehouse")) - } - - test("load data") { - val v1 = "LOAD DATA INPATH 'path' INTO TABLE table1" - val (table, path, isLocal, isOverwrite, partition) = parser.parsePlan(v1).collect { - case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition) - }.head - assert(table.database.isEmpty) - assert(table.table == "table1") - assert(path == "path") - assert(!isLocal) - assert(!isOverwrite) - assert(partition.isEmpty) - - val v2 = "LOAD DATA LOCAL INPATH 'path' OVERWRITE INTO TABLE table1 PARTITION(c='1', d='2')" - val (table2, path2, isLocal2, isOverwrite2, partition2) = parser.parsePlan(v2).collect { - case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition) - }.head - assert(table2.database.isEmpty) - assert(table2.table == "table1") - assert(path2 == "path") - assert(isLocal2) - assert(isOverwrite2) - assert(partition2.nonEmpty) - assert(partition2.get.apply("c") == "1" && partition2.get.apply("d") == "2") - } - - test("Test the default fileformat for Hive-serde tables") { - withSQLConf("hive.default.fileformat" -> "orc") { - val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") - assert(exists) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - withSQLConf("hive.default.fileformat" -> "parquet") { - val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") - assert(exists) - val input = desc.storage.inputFormat - val output = desc.storage.outputFormat - val serde = desc.storage.serde - assert(input == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) - assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } - } - - test("table name with schema") { - // regression test for SPARK-11778 - spark.sql("create schema usrdb") - spark.sql("create table usrdb.test(c int)") - spark.read.table("usrdb.test") - spark.sql("drop table usrdb.test") - spark.sql("drop schema usrdb") - } - - test("SPARK-15887: hive-site.xml should be loaded") { - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - assert(hiveClient.getConf("hive.in.test", "") == "true") - } - - test("create hive serde table with new syntax - basic") { - val sql = - """ - |CREATE TABLE t - |(id int, 
name string COMMENT 'blabla') - |USING hive - |OPTIONS (fileFormat 'parquet', my_prop 1) - |LOCATION '/tmp/file' - |COMMENT 'BLABLA' - """.stripMargin - - val table = analyzeCreateTable(sql) - assert(table.schema == new StructType() - .add("id", "int") - .add("name", "string", nullable = true, comment = "blabla")) - assert(table.provider == Some(DDLUtils.HIVE_PROVIDER)) - assert(table.storage.locationUri == Some(new URI("/tmp/file"))) - assert(table.storage.properties == Map("my_prop" -> "1")) - assert(table.comment == Some("BLABLA")) - - assert(table.storage.inputFormat == - Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) - assert(table.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - assert(table.storage.serde == - Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } - - test("create hive serde table with new syntax - with partition and bucketing") { - val v1 = "CREATE TABLE t (c1 int, c2 int) USING hive PARTITIONED BY (c2)" - val table = analyzeCreateTable(v1) - assert(table.schema == new StructType().add("c1", "int").add("c2", "int")) - assert(table.partitionColumnNames == Seq("c2")) - // check the default formats - assert(table.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(table.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) - assert(table.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) - - val v2 = "CREATE TABLE t (c1 int, c2 int) USING hive CLUSTERED BY (c2) INTO 4 BUCKETS" - val e2 = intercept[AnalysisException](analyzeCreateTable(v2)) - assert(e2.message.contains("Creating bucketed Hive serde table is not supported yet")) - - val v3 = - """ - |CREATE TABLE t (c1 int, c2 int) USING hive - |PARTITIONED BY (c2) - |CLUSTERED BY (c2) INTO 4 BUCKETS""".stripMargin - val e3 = intercept[AnalysisException](analyzeCreateTable(v3)) - assert(e3.message.contains("Creating bucketed Hive serde table is not supported yet")) - } - - test("create hive serde table with new syntax - Hive options error checking") { - val v1 = "CREATE TABLE t (c1 int) USING hive OPTIONS (inputFormat 'abc')" - val e1 = intercept[IllegalArgumentException](analyzeCreateTable(v1)) - assert(e1.getMessage.contains("Cannot specify only inputFormat or outputFormat")) - - val v2 = "CREATE TABLE t (c1 int) USING hive OPTIONS " + - "(fileFormat 'x', inputFormat 'a', outputFormat 'b')" - val e2 = intercept[IllegalArgumentException](analyzeCreateTable(v2)) - assert(e2.getMessage.contains( - "Cannot specify fileFormat and inputFormat/outputFormat together")) - - val v3 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', serde 'a')" - val e3 = intercept[IllegalArgumentException](analyzeCreateTable(v3)) - assert(e3.getMessage.contains("fileFormat 'parquet' already specifies a serde")) - - val v4 = "CREATE TABLE t (c1 int) USING hive OPTIONS (serde 'a', fieldDelim ' ')" - val e4 = intercept[IllegalArgumentException](analyzeCreateTable(v4)) - assert(e4.getMessage.contains("Cannot specify delimiters with a custom serde")) - - val v5 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fieldDelim ' ')" - val e5 = intercept[IllegalArgumentException](analyzeCreateTable(v5)) - assert(e5.getMessage.contains("Cannot specify delimiters without fileFormat")) - - val v6 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', fieldDelim ' ')" - val e6 = intercept[IllegalArgumentException](analyzeCreateTable(v6)) - 
assert(e6.getMessage.contains( - "Cannot specify delimiters as they are only compatible with fileFormat 'textfile'")) - - // The value of 'fileFormat' option is case-insensitive. - val v7 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'TEXTFILE', lineDelim ',')" - val e7 = intercept[IllegalArgumentException](analyzeCreateTable(v7)) - assert(e7.getMessage.contains("Hive data source only support newline '\\n' as line delimiter")) - - val v8 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'wrong')" - val e8 = intercept[IllegalArgumentException](analyzeCreateTable(v8)) - assert(e8.getMessage.contains("invalid fileFormat: 'wrong'")) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala deleted file mode 100644 index 705d43f1f3ab..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala +++ /dev/null @@ -1,264 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.net.URI - -import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils - - -class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest - with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { - - // To test `HiveExternalCatalog`, we need to read/write the raw table meta from/to hive client. 
- val hiveClient: HiveClient = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - - val tempDir = Utils.createTempDir().getCanonicalFile - val tempDirUri = tempDir.toURI - val tempDirStr = tempDir.getAbsolutePath - - override def beforeEach(): Unit = { - sql("CREATE DATABASE test_db") - for ((tbl, _) <- rawTablesAndExpectations) { - hiveClient.createTable(tbl, ignoreIfExists = false) - } - } - - override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) - hiveClient.dropDatabase("test_db", ignoreIfNotExists = false, cascade = true) - } - - private def getTableMetadata(tableName: String): CatalogTable = { - spark.sharedState.externalCatalog.getTable("test_db", tableName) - } - - private def defaultTableURI(tableName: String): URI = { - spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db"))) - } - - // Raw table metadata that are dumped from tables created by Spark 2.0. Note that, all spark - // versions prior to 2.1 would generate almost same raw table metadata for a specific table. - val simpleSchema = new StructType().add("i", "int") - val partitionedSchema = new StructType().add("i", "int").add("j", "int") - - lazy val hiveTable = CatalogTable( - identifier = TableIdentifier("tbl1", Some("test_db")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - schema = simpleSchema) - - lazy val externalHiveTable = CatalogTable( - identifier = TableIdentifier("tbl2", Some("test_db")), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - locationUri = Some(tempDirUri), - inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - schema = simpleSchema) - - lazy val partitionedHiveTable = CatalogTable( - identifier = TableIdentifier("tbl3", Some("test_db")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - schema = partitionedSchema, - partitionColumnNames = Seq("j")) - - - val simpleSchemaJson = - """ - |{ - | "type": "struct", - | "fields": [{ - | "name": "i", - | "type": "integer", - | "nullable": true, - | "metadata": {} - | }] - |} - """.stripMargin - - val partitionedSchemaJson = - """ - |{ - | "type": "struct", - | "fields": [{ - | "name": "i", - | "type": "integer", - | "nullable": true, - | "metadata": {} - | }, - | { - | "name": "j", - | "type": "integer", - | "nullable": true, - | "metadata": {} - | }] - |} - """.stripMargin - - lazy val dataSourceTable = CatalogTable( - identifier = TableIdentifier("tbl4", Some("test_db")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - properties = Map("path" -> defaultTableURI("tbl4").toString)), - schema = new StructType(), - provider = Some("json"), - properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema.numParts" -> "1", - "spark.sql.sources.schema.part.0" -> simpleSchemaJson)) - - lazy val hiveCompatibleDataSourceTable = CatalogTable( - identifier = TableIdentifier("tbl5", Some("test_db")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - properties = Map("path" -> 
defaultTableURI("tbl5").toString)), - schema = simpleSchema, - provider = Some("parquet"), - properties = Map( - "spark.sql.sources.provider" -> "parquet", - "spark.sql.sources.schema.numParts" -> "1", - "spark.sql.sources.schema.part.0" -> simpleSchemaJson)) - - lazy val partitionedDataSourceTable = CatalogTable( - identifier = TableIdentifier("tbl6", Some("test_db")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - properties = Map("path" -> defaultTableURI("tbl6").toString)), - schema = new StructType(), - provider = Some("json"), - properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema.numParts" -> "1", - "spark.sql.sources.schema.part.0" -> partitionedSchemaJson, - "spark.sql.sources.schema.numPartCols" -> "1", - "spark.sql.sources.schema.partCol.0" -> "j")) - - lazy val externalDataSourceTable = CatalogTable( - identifier = TableIdentifier("tbl7", Some("test_db")), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - locationUri = Some(new URI(defaultTableURI("tbl7") + "-__PLACEHOLDER__")), - properties = Map("path" -> tempDirStr)), - schema = new StructType(), - provider = Some("json"), - properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema.numParts" -> "1", - "spark.sql.sources.schema.part.0" -> simpleSchemaJson)) - - lazy val hiveCompatibleExternalDataSourceTable = CatalogTable( - identifier = TableIdentifier("tbl8", Some("test_db")), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - locationUri = Some(tempDirUri), - properties = Map("path" -> tempDirStr)), - schema = simpleSchema, - properties = Map( - "spark.sql.sources.provider" -> "parquet", - "spark.sql.sources.schema.numParts" -> "1", - "spark.sql.sources.schema.part.0" -> simpleSchemaJson)) - - lazy val dataSourceTableWithoutSchema = CatalogTable( - identifier = TableIdentifier("tbl9", Some("test_db")), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - locationUri = Some(new URI(defaultTableURI("tbl9") + "-__PLACEHOLDER__")), - properties = Map("path" -> tempDirStr)), - schema = new StructType(), - provider = Some("json"), - properties = Map("spark.sql.sources.provider" -> "json")) - - // A list of all raw tables we want to test, with their expected schema. 
- lazy val rawTablesAndExpectations = Seq( - hiveTable -> simpleSchema, - externalHiveTable -> simpleSchema, - partitionedHiveTable -> partitionedSchema, - dataSourceTable -> simpleSchema, - hiveCompatibleDataSourceTable -> simpleSchema, - partitionedDataSourceTable -> partitionedSchema, - externalDataSourceTable -> simpleSchema, - hiveCompatibleExternalDataSourceTable -> simpleSchema, - dataSourceTableWithoutSchema -> new StructType()) - - test("make sure we can read table created by old version of Spark") { - for ((tbl, expectedSchema) <- rawTablesAndExpectations) { - val readBack = getTableMetadata(tbl.identifier.table) - assert(readBack.schema.sameType(expectedSchema)) - - if (tbl.tableType == CatalogTableType.EXTERNAL) { - // trim the URI prefix - val tableLocation = readBack.storage.locationUri.get.getPath - val expectedLocation = tempDir.toURI.getPath.stripSuffix("/") - assert(tableLocation == expectedLocation) - } - } - } - - test("make sure we can alter table location created by old version of Spark") { - withTempDir { dir => - for ((tbl, _) <- rawTablesAndExpectations if tbl.tableType == CatalogTableType.EXTERNAL) { - val path = dir.toURI.toString.stripSuffix("/") - sql(s"ALTER TABLE ${tbl.identifier} SET LOCATION '$path'") - - val readBack = getTableMetadata(tbl.identifier.table) - - // trim the URI prefix - val actualTableLocation = readBack.storage.locationUri.get.getPath - val expected = dir.toURI.getPath.stripSuffix("/") - assert(actualTableLocation == expected) - } - } - } - - test("make sure we can rename table created by old version of Spark") { - for ((tbl, expectedSchema) <- rawTablesAndExpectations) { - val newName = tbl.identifier.table + "_renamed" - sql(s"ALTER TABLE ${tbl.identifier} RENAME TO $newName") - - val readBack = getTableMetadata(newName) - assert(readBack.schema.sameType(expectedSchema)) - - // trim the URI prefix - val actualTableLocation = readBack.storage.locationUri.get.getPath - val expectedLocation = if (tbl.tableType == CatalogTableType.EXTERNAL) { - tempDir.toURI.getPath.stripSuffix("/") - } else { - // trim the URI prefix - defaultTableURI(newName).getPath - } - assert(actualTableLocation == expectedLocation) - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index bd54c043c6ec..d43534d5914d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -63,4 +63,30 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { assert(!rawTable.properties.contains(HiveExternalCatalog.DATASOURCE_PROVIDER)) assert(DDLUtils.isHiveTable(externalCatalog.getTable("db1", "hive_tbl"))) } + + Seq("parquet", "hive").foreach { format => + test(s"Partition columns should be put at the end of table schema for the format $format") { + val catalog = newBasicCatalog() + val newSchema = new StructType() + .add("col1", "int") + .add("col2", "string") + .add("partCol1", "int") + .add("partCol2", "string") + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType() + .add("col1", "int") + .add("partCol1", "int") + .add("partCol2", "string") + .add("col2", "string"), + provider = Some(format), + partitionColumnNames = Seq("partCol1", "partCol2")) + catalog.createTable(table, 
ignoreIfExists = false) + + val restoredTable = externalCatalog.getTable("db1", "tbl") + assert(restoredTable.schema == newSchema) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala new file mode 100644 index 000000000000..305f5b533d59 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File +import java.nio.file.Files + +import org.apache.spark.TestUtils +import org.apache.spark.sql.{QueryTest, Row, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils + +/** + * Test HiveExternalCatalog backward compatibility. + * + * Note that, this test suite will automatically download spark binary packages of different + * versions to a local directory `/tmp/spark-test`. If there is already a spark folder with + * expected version under this local directory, e.g. `/tmp/spark-test/spark-2.0.3`, we will skip the + * downloading for this spark version. + */ +class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { + private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse") + private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") + // For local test, you can set `sparkTestingDir` to a static value like `/tmp/test-spark`, to + // avoid downloading Spark of different versions in each run. + private val sparkTestingDir = Utils.createTempDir(namePrefix = "test-spark") + private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + + override def afterAll(): Unit = { + Utils.deleteRecursively(wareHousePath) + Utils.deleteRecursively(tmpDataDir) + Utils.deleteRecursively(sparkTestingDir) + super.afterAll() + } + + private def downloadSpark(version: String): Unit = { + import scala.sys.process._ + + val url = s"https://d3kbcqa49mib13.cloudfront.net/spark-$version-bin-hadoop2.7.tgz" + + Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).! + + val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath + val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + + Seq("mkdir", targetDir).! + + Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + + Seq("rm", downloaded).! 
+ } + + private def genDataDir(name: String): String = { + new File(tmpDataDir, name).getCanonicalPath + } + + override def beforeAll(): Unit = { + super.beforeAll() + + val tempPyFile = File.createTempFile("test", ".py") + Files.write(tempPyFile.toPath, + s""" + |from pyspark.sql import SparkSession + | + |spark = SparkSession.builder.enableHiveSupport().getOrCreate() + |version_index = spark.conf.get("spark.sql.test.version.index", None) + | + |spark.sql("create table data_source_tbl_{} using json as select 1 i".format(version_index)) + | + |spark.sql("create table hive_compatible_data_source_tbl_" + version_index + \\ + | " using parquet as select 1 i") + | + |json_file = "${genDataDir("json_")}" + str(version_index) + |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file) + |spark.sql("create table external_data_source_tbl_" + version_index + \\ + | "(i int) using json options (path '{}')".format(json_file)) + | + |parquet_file = "${genDataDir("parquet_")}" + str(version_index) + |spark.range(1, 2).selectExpr("cast(id as int) as i").write.parquet(parquet_file) + |spark.sql("create table hive_compatible_external_data_source_tbl_" + version_index + \\ + | "(i int) using parquet options (path '{}')".format(parquet_file)) + | + |json_file2 = "${genDataDir("json2_")}" + str(version_index) + |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file2) + |spark.sql("create table external_table_without_schema_" + version_index + \\ + | " using json options (path '{}')".format(json_file2)) + | + |spark.sql("create view v_{} as select 1 i".format(version_index)) + """.stripMargin.getBytes("utf8")) + + PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => + val sparkHome = new File(sparkTestingDir, s"spark-$version") + if (!sparkHome.exists()) { + downloadSpark(version) + } + + val args = Seq( + "--name", "prepare testing tables", + "--master", "local[2]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.sql.warehouse.dir=${wareHousePath.getCanonicalPath}", + "--conf", s"spark.sql.test.version.index=$index", + "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", + tempPyFile.getCanonicalPath) + runSparkSubmit(args, Some(sparkHome.getCanonicalPath)) + } + + tempPyFile.delete() + } + + test("backward compatibility") { + val args = Seq( + "--class", PROCESS_TABLES.getClass.getName.stripSuffix("$"), + "--name", "HiveExternalCatalog backward compatibility test", + "--master", "local[2]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.sql.warehouse.dir=${wareHousePath.getCanonicalPath}", + "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", + unusedJar.toString) + runSparkSubmit(args) + } +} + +object PROCESS_TABLES extends QueryTest with SQLTestUtils { + // Tests the latest version of every release line. 
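+ // The position of each version in this list is also used as the suffix of the table names created for it in beforeAll (passed down as spark.sql.test.version.index).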
+ val testingVersions = Seq("2.0.2", "2.1.1", "2.2.0") + + protected var spark: SparkSession = _ + + def main(args: Array[String]): Unit = { + val session = SparkSession.builder() + .enableHiveSupport() + .getOrCreate() + spark = session + + testingVersions.indices.foreach { index => + Seq( + s"data_source_tbl_$index", + s"hive_compatible_data_source_tbl_$index", + s"external_data_source_tbl_$index", + s"hive_compatible_external_data_source_tbl_$index", + s"external_table_without_schema_$index").foreach { tbl => + val tableMeta = spark.sharedState.externalCatalog.getTable("default", tbl) + + // make sure we can insert and query these tables. + session.sql(s"insert into $tbl select 2") + checkAnswer(session.sql(s"select * from $tbl"), Row(1) :: Row(2) :: Nil) + checkAnswer(session.sql(s"select i from $tbl where i > 1"), Row(2)) + + // make sure we can rename table. + val newName = tbl + "_renamed" + sql(s"ALTER TABLE $tbl RENAME TO $newName") + val readBack = spark.sharedState.externalCatalog.getTable("default", newName) + + val actualTableLocation = readBack.storage.locationUri.get.getPath + val expectedLocation = if (tableMeta.tableType == CatalogTableType.EXTERNAL) { + tableMeta.storage.locationUri.get.getPath + } else { + spark.sessionState.catalog.defaultTablePath(TableIdentifier(newName, None)).getPath + } + assert(actualTableLocation == expectedLocation) + + // make sure we can alter table location. + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + sql(s"ALTER TABLE ${tbl}_renamed SET LOCATION '$path'") + val readBack = spark.sharedState.externalCatalog.getTable("default", tbl + "_renamed") + val actualTableLocation = readBack.storage.locationUri.get.getPath + val expected = dir.toURI.getPath.stripSuffix("/") + assert(actualTableLocation == expected) + } + } + + // test permanent view + checkAnswer(sql(s"select i from v_$index"), Row(1)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 3de1f4aeb74d..c300660458fd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -28,11 +28,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.io.LongWritable import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { @@ -90,7 +90,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { Literal(0.asInstanceOf[Double]) :: Literal("0") :: Literal(java.sql.Date.valueOf("2014-09-23")) :: - Literal(Decimal(BigDecimal(123.123))) :: + Literal(Decimal(BigDecimal("123.123"))) :: Literal(new java.sql.Timestamp(123123)) :: Literal(Array[Byte](1, 2, 3)) :: Literal.create(Seq[Int](1, 2, 3), ArrayType(IntegerType)) :: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index d8fd68b63d1e..18137e7ea1d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} -import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { import spark.implicits._ @@ -67,6 +67,73 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { assert(aliases.size == 1) } } + + test("Validate catalog metadata for supported data types") { + withTable("t") { + sql( + """ + |CREATE TABLE t ( + |c1 boolean, + |c2 tinyint, + |c3 smallint, + |c4 short, + |c5 bigint, + |c6 long, + |c7 float, + |c8 double, + |c9 date, + |c10 timestamp, + |c11 string, + |c12 char(10), + |c13 varchar(10), + |c14 binary, + |c15 decimal, + |c16 decimal(10), + |c17 decimal(10,2), + |c18 array, + |c19 array, + |c20 array, + |c21 map, + |c22 map, + |c23 struct, + |c24 struct + |) + """.stripMargin) + + val schema = hiveClient.getTable("default", "t").schema + val expectedSchema = new StructType() + .add("c1", "boolean") + .add("c2", "tinyint") + .add("c3", "smallint") + .add("c4", "short") + .add("c5", "bigint") + .add("c6", "long") + .add("c7", "float") + .add("c8", "double") + .add("c9", "date") + .add("c10", "timestamp") + .add("c11", "string") + .add("c12", "string", true, + new MetadataBuilder().putString(HIVE_TYPE_STRING, "char(10)").build()) + .add("c13", "string", true, + new MetadataBuilder().putString(HIVE_TYPE_STRING, "varchar(10)").build()) + .add("c14", "binary") + .add("c15", "decimal") + .add("c16", "decimal(10)") + .add("c17", "decimal(10,2)") + .add("c18", "array") + .add("c19", "array") + .add("c20", "array", true, + new MetadataBuilder().putString(HIVE_TYPE_STRING, "array").build()) + .add("c21", "map") + .add("c22", "map", true, + new MetadataBuilder().putString(HIVE_TYPE_STRING, "map").build()) + .add("c23", "struct") + .add("c24", "struct", true, + new MetadataBuilder().putString(HIVE_TYPE_STRING, "struct").build()) + assert(schema == expectedSchema) + } + } } class DataSourceWithHiveMetastoreCatalogSuite @@ -80,17 +147,17 @@ class DataSourceWithHiveMetastoreCatalogSuite ).coalesce(1) Seq( - "parquet" -> ( + "parquet" -> (( "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" - ), + )), - "orc" -> ( + "orc" -> (( "org.apache.hadoop.hive.ql.io.orc.OrcInputFormat", "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat", "org.apache.hadoop.hive.ql.io.orc.OrcSerde" - ) + )) ).foreach { case (provider, (inputFormat, outputFormat, serde)) => test(s"Persist non-partitioned $provider relation into metastore as managed table") { withTable("t") { @@ -180,5 +247,6 @@ class DataSourceWithHiveMetastoreCatalogSuite } } } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index 319d02613f00..f2d27671094d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -23,10 +23,10 @@ import scala.util.Random import 
org.scalatest.BeforeAndAfterEach +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.datasources.FileStatusCache -import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode.{Value => InferenceMode, _} @@ -46,7 +46,7 @@ class HiveSchemaInferenceSuite override def afterEach(): Unit = { super.afterEach() - spark.sessionState.catalog.tableRelationCache.invalidateAll() + spark.sessionState.catalog.invalidateAllCachedTables() FileStatusCache.resetForTesting() } @@ -71,7 +71,7 @@ class HiveSchemaInferenceSuite name = field, dataType = LongType, nullable = true, - metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, "bigint").build()) + metadata = Metadata.empty) } // and all partition columns as ints val partitionStructFields = partitionCols.map { field => @@ -80,7 +80,7 @@ class HiveSchemaInferenceSuite name = field.toLowerCase, dataType = IntegerType, nullable = true, - metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, "int").build()) + metadata = Metadata.empty) } val schema = StructType(structFields ++ partitionStructFields) @@ -104,7 +104,7 @@ class HiveSchemaInferenceSuite identifier = TableIdentifier(table = TEST_TABLE_NAME, database = Option(DATABASE)), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat( - locationUri = Option(new java.net.URI(dir.getAbsolutePath)), + locationUri = Option(dir.toURI), inputFormat = serde.inputFormat, outputFormat = serde.outputFormat, serde = serde.serde, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 5f15a705a2e9..21b3e281490c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -18,17 +18,11 @@ package org.apache.spark.sql.hive import java.io.{BufferedWriter, File, FileWriter} -import java.sql.Timestamp -import java.util.Date -import scala.collection.mutable.ArrayBuffer import scala.tools.nsc.Properties import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.scalatest.concurrent.Timeouts -import org.scalatest.exceptions.TestFailedDueToTimeoutException -import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.internal.Logging @@ -38,7 +32,6 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} -import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -46,11 +39,10 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} * This suite tests spark-submit with applications using HiveContext. 
*/ class HiveSparkSubmitSuite - extends SparkFunSuite + extends SparkSubmitTestUtils with Matchers with BeforeAndAfterEach - with ResetSystemProperties - with Timeouts { + with ResetSystemProperties { // TODO: rewrite these or mark them as slow tests to be run sparingly @@ -151,7 +143,7 @@ class HiveSparkSubmitSuite // the HiveContext code mistakenly overrides the class loader that contains user classes. // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. val version = Properties.versionNumberString match { - case v if v.startsWith("2.10") || v.startsWith("2.11") => v.substring(0, 4) + case v if v.startsWith("2.12") || v.startsWith("2.11") => v.substring(0, 4) case x => throw new Exception(s"Unsupported Scala Version: $x") } val jarDir = getTestResourcePath("regression-test-SPARK-8489") @@ -333,71 +325,6 @@ class HiveSparkSubmitSuite unusedJar.toString) runSparkSubmit(argsForShowTables) } - - // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. - // This is copied from org.apache.spark.deploy.SparkSubmitSuite - private def runSparkSubmit(args: Seq[String]): Unit = { - val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val history = ArrayBuffer.empty[String] - val sparkSubmit = if (Utils.isWindows) { - // On Windows, `ProcessBuilder.directory` does not change the current working directory. - new File("..\\..\\bin\\spark-submit.cmd").getAbsolutePath - } else { - "./bin/spark-submit" - } - val commands = Seq(sparkSubmit) ++ args - val commandLine = commands.mkString("'", "' '", "'") - - val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome)) - val env = builder.environment() - env.put("SPARK_TESTING", "1") - env.put("SPARK_HOME", sparkHome) - - def captureOutput(source: String)(line: String): Unit = { - // This test suite has some weird behaviors when executed on Jenkins: - // - // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a - // timestamp to provide more diagnosis information. - // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print - // them out for debugging purposes. - val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line" - // scalastyle:off println - println(logLine) - // scalastyle:on println - history += logLine - } - - val process = builder.start() - new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() - new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() - - try { - val exitCode = failAfter(300.seconds) { process.waitFor() } - if (exitCode != 0) { - // include logs in output. Note that logging is async and may not have completed - // at the time this exception is raised - Thread.sleep(1000) - val historyLog = history.mkString("\n") - fail { - s"""spark-submit returned with exit code $exitCode. - |Command line: $commandLine - | - |$historyLog - """.stripMargin - } - } - } catch { - case to: TestFailedDueToTimeoutException => - val historyLog = history.mkString("\n") - fail(s"Timeout of $commandLine" + - s" See the log4j logs for more detail." 
+ - s"\n$historyLog", to) - case t: Throwable => throw t - } finally { - // Ensure we still kill the process in case it timed out - process.destroy() - } - } } object SetMetastoreURLTest extends Logging { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index 667a7ddd8bb6..fdbfcf1a6844 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.QueryTest class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -33,4 +33,13 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton assert(conf(ConfVars.METASTORE_END_FUNCTION_LISTENERS.varname) === "") } } + + test("newTemporaryConfiguration respect spark.hadoop.foo=bar in SparkConf") { + sys.props.put("spark.hadoop.foo", "bar") + Seq(true, false) foreach { useInMemoryDerby => + val hiveConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby) + assert(!hiveConf.contains("spark.hadoop.foo")) + assert(hiveConf("foo") === "bar") + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala similarity index 54% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index d6999af84eac..aa5cae33f5cd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -34,7 +34,7 @@ case class TestData(key: Int, value: String) case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter +class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter with SQLTestUtils { import spark.implicits._ @@ -50,47 +50,53 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef } test("insertInto() HiveTable") { - sql("CREATE TABLE createAndInsertTest (key int, value string)") - - // Add some data. - testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") - - // Make sure the table has also been updated. - checkAnswer( - sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq - ) - - // Add more data. - testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") - - // Make sure the table has been updated. - checkAnswer( - sql("SELECT * FROM createAndInsertTest"), - testData.toDF().collect().toSeq ++ testData.toDF().collect().toSeq - ) - - // Now overwrite. - testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest") - - // Make sure the registered table has also been updated. - checkAnswer( - sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq - ) + withTable("createAndInsertTest") { + sql("CREATE TABLE createAndInsertTest (key int, value string)") + + // Add some data. + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") + + // Make sure the table has also been updated. 
+ checkAnswer( + sql("SELECT * FROM createAndInsertTest"), + testData.collect().toSeq + ) + + // Add more data. + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") + + // Make sure the table has been updated. + checkAnswer( + sql("SELECT * FROM createAndInsertTest"), + testData.toDF().collect().toSeq ++ testData.toDF().collect().toSeq + ) + + // Now overwrite. + testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest") + + // Make sure the registered table has also been updated. + checkAnswer( + sql("SELECT * FROM createAndInsertTest"), + testData.collect().toSeq + ) + } } test("Double create fails when allowExisting = false") { - sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - - intercept[AnalysisException] { + withTable("doubleCreateAndInsertTest") { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + + intercept[AnalysisException] { + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + } } } test("Double create does not fail when allowExisting = true") { - sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") + withTable("doubleCreateAndInsertTest") { + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") + } } test("SPARK-4052: scala.collection.Map as value type of MapType") { @@ -166,72 +172,54 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("DROP TABLE tmp_table") } - test("INSERT OVERWRITE - partition IF NOT EXISTS") { - withTempDir { tmpDir => - val table = "table_with_partition" - withTable(table) { - val selQuery = s"select c1, p1, p2 from $table" - sql( - s""" - |CREATE TABLE $table(c1 string) - |PARTITIONED by (p1 string,p2 string) - |location '${tmpDir.toURI.toString}' - """.stripMargin) - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2='b') - |SELECT 'blarr' - """.stripMargin) - checkAnswer( - sql(selQuery), - Row("blarr", "a", "b")) - - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2='b') - |SELECT 'blarr2' - """.stripMargin) - checkAnswer( - sql(selQuery), - Row("blarr2", "a", "b")) + testPartitionedTable("INSERT OVERWRITE - partition IF NOT EXISTS") { tableName => + val selQuery = s"select a, b, c, d from $tableName" + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c=3) + |SELECT 1, 4 + """.stripMargin) + checkAnswer(sql(selQuery), Row(1, 2, 3, 4)) - var e = intercept[AnalysisException] { - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2) IF NOT EXISTS - |SELECT 'blarr3', 'newPartition' - """.stripMargin) - } - assert(e.getMessage.contains( - "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]")) + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c=3) + |SELECT 5, 6 + """.stripMargin) + checkAnswer(sql(selQuery), Row(5, 2, 3, 6)) + + val e = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c) IF NOT EXISTS + |SELECT 7, 8, 3 + """.stripMargin) + } + assert(e.getMessage.contains( + "Dynamic partitions do not support IF NOT EXISTS. 
Specified partitions with value: [c]")) - e = intercept[AnalysisException] { - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2) IF NOT EXISTS - |SELECT 'blarr3', 'b' - """.stripMargin) - } - assert(e.getMessage.contains( - "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]")) + // If the partition already exists, the insert will overwrite the data + // unless users specify IF NOT EXISTS + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c=3) IF NOT EXISTS + |SELECT 9, 10 + """.stripMargin) + checkAnswer(sql(selQuery), Row(5, 2, 3, 6)) - // If the partition already exists, the insert will overwrite the data - // unless users specify IF NOT EXISTS - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2='b') IF NOT EXISTS - |SELECT 'blarr3' - """.stripMargin) - checkAnswer( - sql(selQuery), - Row("blarr2", "a", "b")) - } - } + // ADD PARTITION has the same effect, even if no actual data is inserted. + sql(s"ALTER TABLE $tableName ADD PARTITION (b=21, c=31)") + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=21, c=31) IF NOT EXISTS + |SELECT 20, 24 + """.stripMargin) + checkAnswer(sql(selQuery), Row(5, 2, 3, 6)) } test("Insert ArrayType.containsNull == false") { @@ -286,29 +274,33 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef test("Test partition mode = strict") { withSQLConf(("hive.exec.dynamic.partition.mode", "strict")) { - sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") - val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + withTable("partitioned") { + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) .toDF("id", "data", "part") - intercept[SparkException] { - data.write.insertInto("partitioned") + intercept[SparkException] { + data.write.insertInto("partitioned") + } } } } test("Detect table partitioning") { withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { - sql("CREATE TABLE source (id bigint, data string, part string)") - val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")).toDF() + withTable("source", "partitioned") { + sql("CREATE TABLE source (id bigint, data string, part string)") + val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")).toDF() - data.write.insertInto("source") - checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) - sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") - // this will pick up the output partitioning from the table definition - spark.table("source").write.insertInto("partitioned") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + // this will pick up the output partitioning from the table definition + spark.table("source").write.insertInto("partitioned") - checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq) + checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq) + } } } @@ -479,19 +471,261 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef testPartitionedTable("insertInto() should reject missing columns") { tableName => - sql("CREATE TABLE t (a INT, b INT)") + withTable("t") { + sql("CREATE TABLE t (a 
INT, b INT)") - intercept[AnalysisException] { - spark.table("t").write.insertInto(tableName) + intercept[AnalysisException] { + spark.table("t").write.insertInto(tableName) + } } } testPartitionedTable("insertInto() should reject extra columns") { tableName => - sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT)") + withTable("t") { + sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT)") - intercept[AnalysisException] { - spark.table("t").write.insertInto(tableName) + intercept[AnalysisException] { + spark.table("t").write.insertInto(tableName) + } } } + + private def testBucketedTable(testName: String)(f: String => Unit): Unit = { + test(s"Hive SerDe table - $testName") { + val hiveTable = "hive_table" + + withTable(hiveTable) { + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + sql( + s""" + |CREATE TABLE $hiveTable (a INT, d INT) + |PARTITIONED BY (b INT, c INT) + |CLUSTERED BY(a) + |SORTED BY(a, d) INTO 256 BUCKETS + |STORED AS TEXTFILE + """.stripMargin) + f(hiveTable) + } + } + } + } + + testBucketedTable("INSERT should NOT fail if strict bucketing is NOT enforced") { + tableName => + withSQLConf("hive.enforce.bucketing" -> "false", "hive.enforce.sorting" -> "false") { + sql(s"INSERT INTO TABLE $tableName SELECT 1, 4, 2 AS c, 3 AS b") + checkAnswer(sql(s"SELECT a, b, c, d FROM $tableName"), Row(1, 2, 3, 4)) + } + } + + testBucketedTable("INSERT should fail if strict bucketing / sorting is enforced") { + tableName => + withSQLConf("hive.enforce.bucketing" -> "true", "hive.enforce.sorting" -> "false") { + intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName SELECT 1, 2, 3, 4") + } + } + withSQLConf("hive.enforce.bucketing" -> "false", "hive.enforce.sorting" -> "true") { + intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName SELECT 1, 2, 3, 4") + } + } + withSQLConf("hive.enforce.bucketing" -> "true", "hive.enforce.sorting" -> "true") { + intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName SELECT 1, 2, 3, 4") + } + } + } + + test("SPARK-20594: hive.exec.stagingdir was deleted by Hive") { + // Set hive.exec.stagingdir under the table directory without start with ".". + withSQLConf("hive.exec.stagingdir" -> "./test") { + withTable("test_table") { + sql("CREATE TABLE test_table (key int)") + sql("INSERT OVERWRITE TABLE test_table SELECT 1") + checkAnswer(sql("SELECT * FROM test_table"), Row(1)) + } + } + } + + test("insert overwrite to dir from hive metastore table") { + withTempDir { dir => + val path = dir.toURI.getPath + + sql(s"INSERT OVERWRITE LOCAL DIRECTORY '${path}' SELECT * FROM src where key < 10") + + sql( + s""" + |INSERT OVERWRITE LOCAL DIRECTORY '${path}' + |STORED AS orc + |SELECT * FROM src where key < 10 + """.stripMargin) + + // use orc data source to check the data of path is right. 
+ withTempView("orc_source") { + sql( + s""" + |CREATE TEMPORARY VIEW orc_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${dir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("select * from orc_source"), + sql("select * from src where key < 10")) + } + } + } + + test("insert overwrite to local dir from temp table") { + withTempView("test_insert_table") { + spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table") + + withTempDir { dir => + val path = dir.toURI.getPath + + sql( + s""" + |INSERT OVERWRITE LOCAL DIRECTORY '${path}' + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |SELECT * FROM test_insert_table + """.stripMargin) + + sql( + s""" + |INSERT OVERWRITE LOCAL DIRECTORY '${path}' + |STORED AS orc + |SELECT * FROM test_insert_table + """.stripMargin) + + // use orc data source to check the data of path is right. + checkAnswer( + spark.read.orc(dir.getCanonicalPath), + sql("select * from test_insert_table")) + } + } + } + + test("insert overwrite to dir from temp table") { + withTempView("test_insert_table") { + spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table") + + withTempDir { dir => + val pathUri = dir.toURI + + sql( + s""" + |INSERT OVERWRITE DIRECTORY '${pathUri}' + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |SELECT * FROM test_insert_table + """.stripMargin) + + sql( + s""" + |INSERT OVERWRITE DIRECTORY '${pathUri}' + |STORED AS orc + |SELECT * FROM test_insert_table + """.stripMargin) + + // use orc data source to check the data of path is right. + checkAnswer( + spark.read.orc(dir.getCanonicalPath), + sql("select * from test_insert_table")) + } + } + } + + test("multi insert overwrite to dir") { + withTempView("test_insert_table") { + spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table") + + withTempDir { dir => + val pathUri = dir.toURI + + withTempDir { dir2 => + val pathUri2 = dir2.toURI + + sql( + s""" + |FROM test_insert_table + |INSERT OVERWRITE DIRECTORY '${pathUri}' + |STORED AS orc + |SELECT id + |INSERT OVERWRITE DIRECTORY '${pathUri2}' + |STORED AS orc + |SELECT * + """.stripMargin) + + // use orc data source to check the data of path is right. 
+ checkAnswer( + spark.read.orc(dir.getCanonicalPath), + sql("select id from test_insert_table")) + + checkAnswer( + spark.read.orc(dir2.getCanonicalPath), + sql("select * from test_insert_table")) + } + } + } + } + + test("insert overwrite to dir to illegal path") { + withTempView("test_insert_table") { + spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table") + + val e = intercept[IllegalArgumentException] { + sql( + s""" + |INSERT OVERWRITE LOCAL DIRECTORY 'abc://a' + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |SELECT * FROM test_insert_table + """.stripMargin) + }.getMessage + + assert(e.contains("Wrong FS: abc://a, expected: file:///")) + } + } + + test("insert overwrite to dir with mixed syntax") { + withTempView("test_insert_table") { + spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table") + + val e = intercept[ParseException] { + sql( + s""" + |INSERT OVERWRITE DIRECTORY 'file://tmp' + |USING json + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |SELECT * FROM test_insert_table + """.stripMargin) + }.getMessage + + assert(e.contains("mismatched input 'ROW'")) + } + } + + test("insert overwrite to dir with multi inserts") { + withTempView("test_insert_table") { + spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table") + + val e = intercept[ParseException] { + sql( + s""" + |INSERT OVERWRITE DIRECTORY 'file://tmp2' + |USING json + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |SELECT * FROM test_insert_table + |INSERT OVERWRITE DIRECTORY 'file://tmp2' + |USING json + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |SELECT * FROM test_insert_table + """.stripMargin) + }.getMessage + + assert(e.contains("mismatched input 'ROW'")) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 15ba61646d03..32db22e704b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.hive diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index b55469481557..f5d41c91270a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -52,11 +52,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile } - // To test `HiveExternalCatalog`, we need to read the raw table metadata(schema, partition - // columns and bucket specification are still in table properties) from hive client. - private def hiveClient: HiveClient = - sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - test("persistent JSON table") { withTable("jsonTable") { sql( @@ -588,7 +583,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv Row(3) :: Row(4) :: Nil) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: HadoopFsRelation, _, _) => // OK + case LogicalRelation(p: HadoopFsRelation, _, _, _) => // OK case _ => fail(s"test_parquet_ctas should have be converted to ${classOf[HadoopFsRelation]}") } @@ -998,7 +993,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv spark.sql("""drop database if exists testdb8156 CASCADE""") } - test("skip hive metadata on table creation") { withTempDir { tempPath => val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) @@ -1350,6 +1344,17 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + Seq("orc", "parquet", "csv", "json", "text").foreach { format => + test(s"SPARK-22146: read files containing special characters using $format") { + val nameWithSpecialChars = s"sp&cial%chars" + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + spark.read.format(format).load(tmpFile) + } + } + } + private def withDebugMode(f: => Unit): Unit = { val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) try { @@ -1359,30 +1364,4 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sparkSession.sparkContext.conf.set(DEBUG_MODE, previousValue) } } - - test("SPARK-18464: support old table which doesn't store schema in table properties") { - withTable("old") { - withTempPath { path => - Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) - val tableDesc = CatalogTable( - identifier = TableIdentifier("old", Some("default")), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - properties = Map("path" -> path.getAbsolutePath) - ), - schema = new StructType(), - provider = Some("parquet"), - properties = Map( - HiveExternalCatalog.DATASOURCE_PROVIDER -> "parquet")) - hiveClient.createTable(tableDesc, ignoreIfExists = false) - - 
checkAnswer(spark.table("old"), Row(1, "a")) - - val expectedSchema = StructType(Seq( - StructField("i", IntegerType, nullable = true), - StructField("j", StringType, nullable = true))) - assert(table("old").schema === expectedSchema) - } - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 4aea6d14efb0..9060ce2e0eb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.hive -import java.net.URI - import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 50506197b313..54d3962a46b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -23,8 +23,8 @@ import java.util.concurrent.{Executors, TimeUnit} import org.scalatest.BeforeAndAfterEach import org.apache.spark.metrics.source.HiveCatalogMetrics -import org.apache.spark.sql.execution.datasources.FileStatusCache import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.datasources.FileStatusCache import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 43b6bf5feeb6..b2dc401ce1ef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.File +import java.sql.Timestamp import com.google.common.io.Files import org.apache.hadoop.fs.FileSystem @@ -68,4 +69,20 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl sql("DROP TABLE IF EXISTS createAndInsertTest") } } + + test("SPARK-21739: Cast expression should initialize timezoneId") { + withTable("table_with_timestamp_partition") { + sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)") + sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " + + "PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)") + + // test for Cast expression in TableReader + checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"), + Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000")))) + + // test for Cast expression in HiveTableScanExec + checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " + + "WHERE ts = '2010-01-01 00:00:00.000'"), Row(1)) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index 4bfab0f9cfbf..fad81c7e9474 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -247,21 +247,16 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with 
TestHiveSing } } - test("hive bucketing is not supported") { + test("hive bucketing is supported") { withTable("t1") { - createRawHiveTable( + sql( s"""CREATE TABLE t1 (a INT, b STRING) |CLUSTERED BY (a) |SORTED BY (b) |INTO 2 BUCKETS """.stripMargin ) - - val cause = intercept[AnalysisException] { - sql("SHOW CREATE TABLE t1") - } - - assert(cause.getMessage.contains(" - bucketing")) + checkCreateTable("t1") } } @@ -330,26 +325,20 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing "last_modified_by", "last_modified_time", "Owner:", - "COLUMN_STATS_ACCURATE", // The following are hive specific schema parameters which we do not need to match exactly. - "numFiles", - "numRows", - "rawDataSize", - "totalSize", "totalNumberFiles", "maxFileSize", - "minFileSize", - // EXTERNAL is not non-deterministic, but it is filtered out for external tables. - "EXTERNAL" + "minFileSize" ) table.copy( createTime = 0L, lastAccessTime = 0L, - properties = table.properties.filterKeys(!nondeterministicProps.contains(_)) + properties = table.properties.filterKeys(!nondeterministicProps.contains(_)), + stats = None, + ignoredProperties = Map.empty ) } - assert(normalize(actual) == normalize(expected)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala new file mode 100644 index 000000000000..ede44df4afe1 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File +import java.sql.Timestamp +import java.util.Date + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.TimeLimits +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import org.apache.spark.util.Utils + +trait SparkSubmitTestUtils extends SparkFunSuite with TimeLimits { + + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. + // This is copied from org.apache.spark.deploy.SparkSubmitSuite + protected def runSparkSubmit(args: Seq[String], sparkHomeOpt: Option[String] = None): Unit = { + val sparkHome = sparkHomeOpt.getOrElse( + sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))) + val history = ArrayBuffer.empty[String] + val sparkSubmit = if (Utils.isWindows) { + // On Windows, `ProcessBuilder.directory` does not change the current working directory. 
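+ // The script path is therefore resolved to an absolute path before the process is started.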
+ new File("..\\..\\bin\\spark-submit.cmd").getAbsolutePath + } else { + "./bin/spark-submit" + } + val commands = Seq(sparkSubmit) ++ args + val commandLine = commands.mkString("'", "' '", "'") + + val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome)) + val env = builder.environment() + env.put("SPARK_TESTING", "1") + env.put("SPARK_HOME", sparkHome) + + def captureOutput(source: String)(line: String): Unit = { + // This test suite has some weird behaviors when executed on Jenkins: + // + // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a + // timestamp to provide more diagnosis information. + // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print + // them out for debugging purposes. + val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line" + // scalastyle:off println + println(logLine) + // scalastyle:on println + history += logLine + } + + val process = builder.start() + new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() + + try { + val exitCode = failAfter(300.seconds) { process.waitFor() } + if (exitCode != 0) { + // include logs in output. Note that logging is async and may not have completed + // at the time this exception is raised + Thread.sleep(1000) + val historyLog = history.mkString("\n") + fail { + s"""spark-submit returned with exit code $exitCode. + |Command line: $commandLine + | + |$historyLog + """.stripMargin + } + } + } catch { + case to: TestFailedDueToTimeoutException => + val historyLog = history.mkString("\n") + fail(s"Timeout of $commandLine" + + s" See the log4j logs for more detail." + + s"\n$historyLog", to) + case t: Throwable => throw t + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 3191b9975fbf..9ff9ecf7f367 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -20,20 +20,27 @@ package org.apache.spark.sql.hive import java.io.{File, PrintWriter} import scala.reflect.ClassTag +import scala.util.matching.Regex + +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} +import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.hive.HiveExternalCatalog._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { - test("Hive serde tables should fallback to HDFS for size estimation") { +class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { + test("Hive 
serde tables should fallback to HDFS for size estimation") { withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true") { withTable("csv_table") { withTempDir { tempDir => @@ -59,13 +66,13 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto |LOCATION '${tempDir.toURI}'""".stripMargin) val relation = spark.table("csv_table").queryExecution.analyzed.children.head - .asInstanceOf[CatalogRelation] + .asInstanceOf[HiveTableRelation] - val properties = relation.tableMeta.properties + val properties = relation.tableMeta.ignoredProperties assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0") assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0") - val sizeInBytes = relation.stats(conf).sizeInBytes + val sizeInBytes = relation.stats.sizeInBytes assert(sizeInBytes === BigInt(file1.length() + file2.length())) } } @@ -74,92 +81,425 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("analyze Hive serde tables") { def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes // Non-partitioned table - sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() - sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + val nonPartTable = "non_part_table" + withTable(nonPartTable) { + sql(s"CREATE TABLE $nonPartTable (key STRING, value STRING)") + sql(s"INSERT INTO TABLE $nonPartTable SELECT * FROM src") + sql(s"INSERT INTO TABLE $nonPartTable SELECT * FROM src") - sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") + sql(s"ANALYZE TABLE $nonPartTable COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable") === BigInt(11624)) - - sql("DROP TABLE analyzeTable").collect() + assert(queryTotalSize(nonPartTable) === BigInt(11624)) + } // Partitioned table - sql( - """ - |CREATE TABLE analyzeTable_part (key STRING, value STRING) PARTITIONED BY (ds STRING) - """.stripMargin).collect() - sql( - """ - |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-01') - |SELECT * FROM src - """.stripMargin).collect() - sql( - """ - |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-02') - |SELECT * FROM src - """.stripMargin).collect() - sql( - """ - |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-03') - |SELECT * FROM src - """.stripMargin).collect() + val partTable = "part_table" + withTable(partTable) { + sql(s"CREATE TABLE $partTable (key STRING, value STRING) PARTITIONED BY (ds STRING)") + sql(s"INSERT INTO TABLE $partTable PARTITION (ds='2010-01-01') SELECT * FROM src") + sql(s"INSERT INTO TABLE $partTable PARTITION (ds='2010-01-02') SELECT * FROM src") + sql(s"INSERT INTO TABLE $partTable PARTITION (ds='2010-01-03') SELECT * FROM src") - assert(queryTotalSize("analyzeTable_part") === spark.sessionState.conf.defaultSizeInBytes) + assert(queryTotalSize(partTable) === spark.sessionState.conf.defaultSizeInBytes) - sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") + sql(s"ANALYZE TABLE $partTable COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) - - sql("DROP TABLE analyzeTable_part").collect() + assert(queryTotalSize(partTable) === BigInt(17436)) + } // Try to analyze a temp table - sql("""SELECT * FROM src""").createOrReplaceTempView("tempTable") - 
intercept[AnalysisException] { - sql("ANALYZE TABLE tempTable COMPUTE STATISTICS") + withView("tempTable") { + sql("""SELECT * FROM src""").createOrReplaceTempView("tempTable") + intercept[AnalysisException] { + sql("ANALYZE TABLE tempTable COMPUTE STATISTICS") + } } - spark.sessionState.catalog.dropTable( - TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } - test("analyzing views is not supported") { - def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { - val err = intercept[AnalysisException] { - sql(analyzeCommand) + test("analyze non hive compatible datasource tables") { + val table = "parquet_tab" + withTable(table) { + sql( + s""" + |CREATE TABLE $table (a int, b int) + |USING parquet + |OPTIONS (skipHiveMetadata true) + """.stripMargin) + + // Verify that the schema stored in catalog is a dummy one used for + // data source tables. The actual schema is stored in table properties. + val rawSchema = hiveClient.getTable("default", table).schema + val metadata = new MetadataBuilder().putString("comment", "from deserializer").build() + val expectedRawSchema = new StructType().add("col", "array", true, metadata) + assert(rawSchema == expectedRawSchema) + + val actualSchema = spark.sharedState.externalCatalog.getTable("default", table).schema + val expectedActualSchema = new StructType() + .add("a", "int") + .add("b", "int") + assert(actualSchema == expectedActualSchema) + + sql(s"INSERT INTO $table VALUES (1, 1)") + sql(s"INSERT INTO $table VALUES (2, 1)") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS a, b") + val fetchedStats0 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) + assert(fetchedStats0.get.colStats == Map( + "a" -> ColumnStat(2, Some(1), Some(2), 0, 4, 4), + "b" -> ColumnStat(1, Some(1), Some(1), 0, 4, 4))) + } + } + + test("Analyze hive serde tables when schema is not same as schema in table properties") { + val table = "hive_serde" + withTable(table) { + sql(s"CREATE TABLE $table (C1 INT, C2 STRING, C3 DOUBLE)") + + // Verify that the table schema stored in hive catalog is + // different than the schema stored in table properties. 
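+ // (Hive stores the column names lower-cased, while the schema kept in table properties preserves the original case.)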
+ val rawSchema = hiveClient.getTable("default", table).schema + val expectedRawSchema = new StructType() + .add("c1", "int") + .add("c2", "string") + .add("c3", "double") + assert(rawSchema == expectedRawSchema) + + val actualSchema = spark.sharedState.externalCatalog.getTable("default", table).schema + val expectedActualSchema = new StructType() + .add("C1", "int") + .add("C2", "string") + .add("C3", "double") + assert(actualSchema == expectedActualSchema) + + sql(s"INSERT INTO TABLE $table SELECT 1, 'a', 10.0") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS C1") + val fetchedStats1 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get + assert(fetchedStats1.colStats == Map( + "C1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, + avgLen = 4, maxLen = 4))) + } + } + + test("SPARK-21079 - analyze table with location different than that of individual partitions") { + val tableName = "analyzeTable_part" + withTable(tableName) { + withTempPath { path => + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') SELECT * FROM src") + } + + sql(s"ALTER TABLE $tableName SET LOCATION '${path.toURI}'") + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + + assert(getCatalogStatistics(tableName).sizeInBytes === BigInt(17436)) } - assert(err.message.contains("ANALYZE TABLE is not supported")) } + } - val tableName = "tbl" + test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { + val sourceTableName = "analyzeTable_part" + val tableName = "analyzeTable_part_vis" + withTable(sourceTableName, tableName) { + withTempPath { path => + // Create a table with 3 partitions all located under a single top-level directory 'path' + sql( + s""" + |CREATE TABLE $sourceTableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '${path.toURI}' + """.stripMargin) + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql( + s""" + |INSERT INTO TABLE $sourceTableName PARTITION (ds='$ds') + |SELECT * FROM src + """.stripMargin) + } + + // Create another table referring to the same location + sql( + s""" + |CREATE TABLE $tableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '${path.toURI}' + """.stripMargin) + + // Register only one of the partitions found on disk + val ds = partitionDates.head + sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')") + + // Analyze original table - expect 3 partitions + sql(s"ANALYZE TABLE $sourceTableName COMPUTE STATISTICS noscan") + assert(getCatalogStatistics(sourceTableName).sizeInBytes === BigInt(3 * 5812)) + + // Analyze partial-copy table - expect only 1 partition + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + assert(getCatalogStatistics(tableName).sizeInBytes === BigInt(5812)) + } + } + } + + test("analyze single partition") { + val tableName = "analyzeTable_part" + + def queryStats(ds: String): CatalogStatistics = { + val partition = + spark.sessionState.catalog.getPartition(TableIdentifier(tableName), Map("ds" -> ds)) + partition.stats.get + } + + def createPartition(ds: String, query: String): Unit = { + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') $query") + } + + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value 
STRING) PARTITIONED BY (ds STRING)") + + createPartition("2010-01-01", "SELECT '1', 'A' from src") + createPartition("2010-01-02", "SELECT '1', 'A' from src UNION ALL SELECT '1', 'A' from src") + createPartition("2010-01-03", "SELECT '1', 'A' from src") + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS NOSCAN") + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS NOSCAN") + + assert(queryStats("2010-01-01").rowCount === None) + assert(queryStats("2010-01-01").sizeInBytes === 2000) + + assert(queryStats("2010-01-02").rowCount === None) + assert(queryStats("2010-01-02").sizeInBytes === 2*2000) + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS") + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS") + + assert(queryStats("2010-01-01").rowCount.get === 500) + assert(queryStats("2010-01-01").sizeInBytes === 2000) + + assert(queryStats("2010-01-02").rowCount.get === 2*500) + assert(queryStats("2010-01-02").sizeInBytes === 2*2000) + } + } + + test("analyze a set of partitions") { + val tableName = "analyzeTable_part" + + def queryStats(ds: String, hr: String): Option[CatalogStatistics] = { + val tableId = TableIdentifier(tableName) + val partition = + spark.sessionState.catalog.getPartition(tableId, Map("ds" -> ds, "hr" -> hr)) + partition.stats + } + + def assertPartitionStats( + ds: String, + hr: String, + rowCount: Option[BigInt], + sizeInBytes: BigInt): Unit = { + val stats = queryStats(ds, hr).get + assert(stats.rowCount === rowCount) + assert(stats.sizeInBytes === sizeInBytes) + } + + def createPartition(ds: String, hr: Int, query: String): Unit = { + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds', hr=$hr) $query") + } + + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING, hr INT)") + + createPartition("2010-01-01", 10, "SELECT '1', 'A' from src") + createPartition("2010-01-01", 11, "SELECT '1', 'A' from src") + createPartition("2010-01-02", 10, "SELECT '1', 'A' from src") + createPartition("2010-01-02", 11, + "SELECT '1', 'A' from src UNION ALL SELECT '1', 'A' from src") + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS NOSCAN") + + assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000) + assert(queryStats("2010-01-02", "10") === None) + assert(queryStats("2010-01-02", "11") === None) + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS NOSCAN") + + assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "10", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "11", rowCount = None, sizeInBytes = 2*2000) + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS") + + assertPartitionStats("2010-01-01", "10", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "10", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "11", rowCount = None, sizeInBytes = 2*2000) + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS") + + assertPartitionStats("2010-01-01", "10", rowCount = Some(500), sizeInBytes = 2000) + 
assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "10", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "11", rowCount = Some(2*500), sizeInBytes = 2*2000) + } + } + + test("analyze all partitions") { + val tableName = "analyzeTable_part" + + def assertPartitionStats( + ds: String, + hr: String, + rowCount: Option[BigInt], + sizeInBytes: BigInt): Unit = { + val stats = spark.sessionState.catalog.getPartition(TableIdentifier(tableName), + Map("ds" -> ds, "hr" -> hr)).stats.get + assert(stats.rowCount === rowCount) + assert(stats.sizeInBytes === sizeInBytes) + } + + def createPartition(ds: String, hr: Int, query: String): Unit = { + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds', hr=$hr) $query") + } + + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING, hr INT)") + + createPartition("2010-01-01", 10, "SELECT '1', 'A' from src") + createPartition("2010-01-01", 11, "SELECT '1', 'A' from src") + createPartition("2010-01-02", 10, "SELECT '1', 'A' from src") + createPartition("2010-01-02", 11, + "SELECT '1', 'A' from src UNION ALL SELECT '1', 'A' from src") + + sql(s"ANALYZE TABLE $tableName PARTITION (ds, hr) COMPUTE STATISTICS NOSCAN") + + assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "11", rowCount = None, sizeInBytes = 2*2000) + + sql(s"ANALYZE TABLE $tableName PARTITION (ds, hr) COMPUTE STATISTICS") + + assertPartitionStats("2010-01-01", "10", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "11", rowCount = Some(2*500), sizeInBytes = 2*2000) + } + } + + test("analyze partitions for an empty table") { + val tableName = "analyzeTable_part" + + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + // make sure there is no exception + sql(s"ANALYZE TABLE $tableName PARTITION (ds) COMPUTE STATISTICS NOSCAN") + + // make sure there is no exception + sql(s"ANALYZE TABLE $tableName PARTITION (ds) COMPUTE STATISTICS") + } + } + + test("analyze partitions case sensitivity") { + val tableName = "analyzeTable_part" withTable(tableName) { - spark.range(10).write.saveAsTable(tableName) - val viewName = "view" - withView(viewName) { - sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") - assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") - assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='2010-01-01') SELECT * FROM src") + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql(s"ANALYZE TABLE $tableName PARTITION (DS='2010-01-01') COMPUTE STATISTICS") + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val message = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName PARTITION (DS='2010-01-01') COMPUTE STATISTICS") + }.getMessage + assert(message.contains( + s"DS is not a valid partition column in table 
`default`.`${tableName.toLowerCase}`")) } } } - private def checkTableStats( - tableName: String, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { - val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats + test("analyze partial partition specifications") { - if (hasSizeInBytes || expectedRowCounts.nonEmpty) { - assert(stats.isDefined) - assert(stats.get.sizeInBytes > 0) - assert(stats.get.rowCount === expectedRowCounts) - } else { - assert(stats.isEmpty) + val tableName = "analyzeTable_part" + + def assertAnalysisException(partitionSpec: String): Unit = { + val message = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName $partitionSpec COMPUTE STATISTICS") + }.getMessage + assert(message.contains("The list of partition columns with values " + + s"in partition specification for table '${tableName.toLowerCase}' in database 'default' " + + "is not a prefix of the list of partition columns defined in the table schema")) + } + + withTable(tableName) { + sql( + s""" + |CREATE TABLE $tableName (key STRING, value STRING) + |PARTITIONED BY (a STRING, b INT, c STRING) + """.stripMargin) + + sql(s"INSERT INTO TABLE $tableName PARTITION (a='a1', b=10, c='c1') SELECT * FROM src") + + sql(s"ANALYZE TABLE $tableName PARTITION (a='a1') COMPUTE STATISTICS") + sql(s"ANALYZE TABLE $tableName PARTITION (a='a1', b=10) COMPUTE STATISTICS") + sql(s"ANALYZE TABLE $tableName PARTITION (A='a1', b=10) COMPUTE STATISTICS") + sql(s"ANALYZE TABLE $tableName PARTITION (b=10, a='a1') COMPUTE STATISTICS") + sql(s"ANALYZE TABLE $tableName PARTITION (b=10, A='a1') COMPUTE STATISTICS") + + assertAnalysisException("PARTITION (b=10)") + assertAnalysisException("PARTITION (a, b=10)") + assertAnalysisException("PARTITION (b=10, c='c1')") + assertAnalysisException("PARTITION (a, b=10, c='c1')") + assertAnalysisException("PARTITION (c='c1')") + assertAnalysisException("PARTITION (a, b, c='c1')") + assertAnalysisException("PARTITION (a='a1', c='c1')") + assertAnalysisException("PARTITION (a='a1', b, c='c1')") + } + } + + test("analyze non-existent partition") { + + def assertAnalysisException(analyzeCommand: String, errorMessage: String): Unit = { + val message = intercept[AnalysisException] { + sql(analyzeCommand) + }.getMessage + assert(message.contains(errorMessage)) } - stats + val tableName = "analyzeTable_part" + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='2010-01-01') SELECT * FROM src") + + assertAnalysisException( + s"ANALYZE TABLE $tableName PARTITION (hour=20) COMPUTE STATISTICS", + s"hour is not a valid partition column in table `default`.`${tableName.toLowerCase}`" + ) + + assertAnalysisException( + s"ANALYZE TABLE $tableName PARTITION (hour) COMPUTE STATISTICS", + s"hour is not a valid partition column in table `default`.`${tableName.toLowerCase}`" + ) + + intercept[NoSuchPartitionException] { + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2011-02-30') COMPUTE STATISTICS") + } + } } test("test table-level statistics for hive tables created in HiveExternalCatalog") { @@ -175,7 +515,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") checkTableStats( textTable, - hasSizeInBytes = false, + hasSizeInBytes = true, expectedRowCounts = None) // noscan won't count the number of rows @@ -191,27 +531,408 @@ class StatisticsSuite 
extends StatisticsCollectionTestBase with TestHiveSingleto } } - test("test elimination of the influences of the old stats") { + test("keep existing row count in stats with noscan if table is not changed") { val textTable = "textTable" withTable(textTable) { - sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") + sql(s"CREATE TABLE $textTable (key STRING, value STRING)") sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") val fetchedStats1 = checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - // when the total size is not changed, the old row count is kept + // when the table is not changed, total size is the same, and the old row count is kept val fetchedStats2 = checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1 == fetchedStats2) + } + } - sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") - sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - // update total size and remove the old and invalid row count + test("keep existing column stats if table is not changed") { + val table = "update_col_stats_table" + withTable(table) { + sql(s"CREATE TABLE $table (c1 INT, c2 STRING, c3 DOUBLE)") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats0 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetchedStats0.get.colStats == Map("c1" -> ColumnStat(0, None, None, 0, 4, 4))) + + // Insert new data and analyze: have the latest column stats. + sql(s"INSERT INTO TABLE $table SELECT 1, 'a', 10.0") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats1 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get + assert(fetchedStats1.colStats == Map( + "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, + avgLen = 4, maxLen = 4))) + + // Analyze another column: since the table is not changed, the previous column stats are kept. + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") + val fetchedStats2 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get + assert(fetchedStats2.colStats == Map( + "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, + avgLen = 4, maxLen = 4), + "c2" -> ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, + avgLen = 1, maxLen = 1))) + + // Insert new data and analyze: stale column stats are removed and newly collected column + // stats are added. 
+ sql(s"INSERT INTO TABLE $table SELECT 2, 'b', 20.0") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1, c3") val fetchedStats3 = - checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = None) - assert(fetchedStats3.get.sizeInBytes > fetchedStats2.get.sizeInBytes) + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)).get + assert(fetchedStats3.colStats == Map( + "c1" -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + avgLen = 4, maxLen = 4), + "c3" -> ColumnStat(distinctCount = 2, min = Some(10.0), max = Some(20.0), nullCount = 0, + avgLen = 8, maxLen = 8))) + } + } + + private def createNonPartitionedTable( + tabName: String, + analyzedBySpark: Boolean = true, + analyzedByHive: Boolean = true): Unit = { + sql( + s""" + |CREATE TABLE $tabName (key STRING, value STRING) + |STORED AS TEXTFILE + |TBLPROPERTIES ('prop1' = 'val1', 'prop2' = 'val2') + """.stripMargin) + sql(s"INSERT INTO TABLE $tabName SELECT * FROM src") + if (analyzedBySpark) sql(s"ANALYZE TABLE $tabName COMPUTE STATISTICS") + // This is to mimic the scenario in which Hive genrates statistics before we reading it + if (analyzedByHive) hiveClient.runSqlHive(s"ANALYZE TABLE $tabName COMPUTE STATISTICS") + val describeResult1 = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") + + val tableMetadata = getCatalogTable(tabName).properties + // statistics info is not contained in the metadata of the original table + assert(Seq(StatsSetupConst.COLUMN_STATS_ACCURATE, + StatsSetupConst.NUM_FILES, + StatsSetupConst.NUM_PARTITIONS, + StatsSetupConst.ROW_COUNT, + StatsSetupConst.RAW_DATA_SIZE, + StatsSetupConst.TOTAL_SIZE).forall(!tableMetadata.contains(_))) + + if (analyzedByHive) { + assert(StringUtils.filterPattern(describeResult1, "*numRows\\s+500*").nonEmpty) + } else { + assert(StringUtils.filterPattern(describeResult1, "*numRows\\s+500*").isEmpty) + } + } + + private def extractStatsPropValues( + descOutput: Seq[String], + propKey: String): Option[BigInt] = { + val str = descOutput + .filterNot(_.contains(STATISTICS_PREFIX)) + .filter(_.contains(propKey)) + if (str.isEmpty) { + None + } else { + assert(str.length == 1, "found more than one matches") + val pattern = new Regex(s"""$propKey\\s+(-?\\d+)""") + val pattern(value) = str.head.trim + Option(BigInt(value)) + } + } + + test("get statistics when not analyzed in Hive or Spark") { + val tabName = "tab1" + withTable(tabName) { + createNonPartitionedTable(tabName, analyzedByHive = false, analyzedBySpark = false) + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = None) + + // ALTER TABLE SET TBLPROPERTIES invalidates some contents of Hive specific statistics + // This is triggered by the Hive alterTable API + val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") + + val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") + val numRows = extractStatsPropValues(describeResult, "numRows") + val totalSize = extractStatsPropValues(describeResult, "totalSize") + assert(rawDataSize.isEmpty, "rawDataSize should not be shown without table analysis") + assert(numRows.isEmpty, "numRows should not be shown without table analysis") + assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") + } + } + + test("alter table rename after analyze table") { + Seq(true, false).foreach { analyzedBySpark => + val oldName = "tab1" + val newName = "tab2" + withTable(oldName, newName) { + createNonPartitionedTable(oldName, analyzedByHive = true, 
analyzedBySpark = analyzedBySpark) + val fetchedStats1 = checkTableStats( + oldName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + sql(s"ALTER TABLE $oldName RENAME TO $newName") + val fetchedStats2 = checkTableStats( + newName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + assert(fetchedStats1 == fetchedStats2) + + // ALTER TABLE RENAME does not affect the contents of Hive specific statistics + val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $newName") + + val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") + val numRows = extractStatsPropValues(describeResult, "numRows") + val totalSize = extractStatsPropValues(describeResult, "totalSize") + assert(rawDataSize.isDefined && rawDataSize.get > 0, "rawDataSize is lost") + assert(numRows.isDefined && numRows.get == 500, "numRows is lost") + assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") + } + } + } + + test("alter table should not have the side effect to store statistics in Spark side") { + val table = "alter_table_side_effect" + withTable(table) { + sql(s"CREATE TABLE $table (i string, j string)") + sql(s"INSERT INTO TABLE $table SELECT 'a', 'b'") + val catalogTable1 = getCatalogTable(table) + val hiveSize1 = BigInt(catalogTable1.ignoredProperties(StatsSetupConst.TOTAL_SIZE)) + + sql(s"ALTER TABLE $table SET TBLPROPERTIES ('prop1' = 'a')") + + sql(s"INSERT INTO TABLE $table SELECT 'c', 'd'") + val catalogTable2 = getCatalogTable(table) + val hiveSize2 = BigInt(catalogTable2.ignoredProperties(StatsSetupConst.TOTAL_SIZE)) + // After insertion, Hive's stats should be changed. + assert(hiveSize2 > hiveSize1) + // We haven't generated stats in Spark, so we should still use Hive's stats here. + assert(catalogTable2.stats.get.sizeInBytes == hiveSize2) + } + } + + private def testAlterTableProperties(tabName: String, alterTablePropCmd: String): Unit = { + Seq(true, false).foreach { analyzedBySpark => + withTable(tabName) { + createNonPartitionedTable(tabName, analyzedByHive = true, analyzedBySpark = analyzedBySpark) + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + + // Run ALTER TABLE command + sql(alterTablePropCmd) + + val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") + + val totalSize = extractStatsPropValues(describeResult, "totalSize") + assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") + + // ALTER TABLE SET/UNSET TBLPROPERTIES invalidates some Hive specific statistics, but not + // Spark specific statistics. This is triggered by the Hive alterTable API. 
+ val numRows = extractStatsPropValues(describeResult, "numRows") + assert(numRows.isDefined && numRows.get == -1, "numRows is lost") + val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") + assert(rawDataSize.isDefined && rawDataSize.get == -1, "rawDataSize is lost") + + if (analyzedBySpark) { + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + } else { + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = None) + } + } + } + } + + test("alter table SET TBLPROPERTIES after analyze table") { + testAlterTableProperties("set_prop_table", + "ALTER TABLE set_prop_table SET TBLPROPERTIES ('foo' = 'a')") + } + + test("alter table UNSET TBLPROPERTIES after analyze table") { + testAlterTableProperties("unset_prop_table", + "ALTER TABLE unset_prop_table UNSET TBLPROPERTIES ('prop1')") + } + + /** + * To see if stats exist, we need to check spark's stats properties instead of catalog + * statistics, because hive would change stats in metastore and thus change catalog statistics. + */ + private def getStatsProperties(tableName: String): Map[String, String] = { + val hTable = hiveClient.getTable(spark.sessionState.catalog.getCurrentDatabase, tableName) + hTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + } + + test("change stats after insert command for hive table") { + val table = s"change_stats_insert_hive_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string)") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched2.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + } + } + } + } + + test("change stats after load data command") { + val table = "change_stats_load_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) STORED AS PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + withTempDir { loadPath => + // load data command + val file = new File(loadPath + "/data") + val writer = new PrintWriter(file) + writer.write("2,xyz") + writer.close() + sql(s"LOAD DATA INPATH '${loadPath.toURI.toString}' INTO TABLE $table") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched2.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + } + } + } + } + } + + test("change stats after 
add/drop partition command") { + val table = "change_stats_part_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) PARTITIONED BY (ds STRING, hr STRING)") + // table has two partitions initially + for (ds <- Seq("2008-04-08"); hr <- Seq("11", "12")) { + sql(s"INSERT OVERWRITE TABLE $table PARTITION (ds='$ds',hr='$hr') SELECT 1, 'a'") + } + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + withTempPaths(numPaths = 2) { case Seq(dir1, dir2) => + val file1 = new File(dir1 + "/data") + val writer1 = new PrintWriter(file1) + writer1.write("1,a") + writer1.close() + + val file2 = new File(dir2 + "/data") + val writer2 = new PrintWriter(file2) + writer2.write("1,a") + writer2.close() + + // add partition command + sql( + s""" + |ALTER TABLE $table ADD + |PARTITION (ds='2008-04-09', hr='11') LOCATION '${dir1.toURI.toString}' + |PARTITION (ds='2008-04-09', hr='12') LOCATION '${dir2.toURI.toString}' + """.stripMargin) + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > fetched1.get.sizeInBytes) + assert(fetched2.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched2.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + + // now the table has four partitions, generate stats again + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched3 = checkTableStats( + table, hasSizeInBytes = true, expectedRowCounts = Some(4)) + assert(fetched3.get.sizeInBytes > 0) + assert(fetched3.get.colStats.size == 2) + + // drop partition command + sql(s"ALTER TABLE $table DROP PARTITION (ds='2008-04-08'), PARTITION (hr='12')") + assert(spark.sessionState.catalog.listPartitions(TableIdentifier(table)) + .map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) + // only one partition left + if (autoUpdate) { + val fetched4 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched4.get.sizeInBytes < fetched1.get.sizeInBytes) + assert(fetched4.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched4.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + } + } + } + } + } + + test("add/drop partitions - managed table") { + val catalog = spark.sessionState.catalog + val managedTable = "partitionedTable" + withTable(managedTable) { + sql( + s""" + |CREATE TABLE $managedTable (key INT, value STRING) + |PARTITIONED BY (ds STRING, hr STRING) + """.stripMargin) + + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + sql( + s""" + |INSERT OVERWRITE TABLE $managedTable + |partition (ds='$ds',hr='$hr') + |SELECT 1, 'a' + """.stripMargin) + } + + checkTableStats( + managedTable, hasSizeInBytes = false, expectedRowCounts = None) + + sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") + + val stats1 = checkTableStats( + managedTable, hasSizeInBytes = true, expectedRowCounts = Some(4)) + + sql( + s""" + |ALTER TABLE $managedTable DROP PARTITION (ds='2008-04-08'), + |PARTITION (hr='12') + """.stripMargin) + 
assert(catalog.listPartitions(TableIdentifier(managedTable)).map(_.spec).toSet == + Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) + + sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") + + val stats2 = checkTableStats( + managedTable, hasSizeInBytes = true, expectedRowCounts = Some(1)) + assert(stats1.get.sizeInBytes > stats2.get.sizeInBytes) + + sql(s"ALTER TABLE $managedTable ADD PARTITION (ds='2008-04-08', hr='12')") + sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") + val stats4 = checkTableStats( + managedTable, hasSizeInBytes = true, expectedRowCounts = Some(1)) + + assert(stats1.get.sizeInBytes > stats4.get.sizeInBytes) + assert(stats4.get.sizeInBytes == stats2.get.sizeInBytes) } } @@ -226,13 +947,14 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it // for robustness - withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "true") { + withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "true") { checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) } - withSQLConf("spark.sql.hive.convertMetastoreOrc" -> "true") { - checkTableStats(orcTable, hasSizeInBytes = false, expectedRowCounts = None) + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { + // We still can get tableSize from Hive before Analyze + checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = None) sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) } @@ -254,7 +976,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) // Validate statistics - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client val table = hiveClient.getTable("default", tableName) val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats")) @@ -348,8 +1069,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val parquetTable = "parquetTable" withTable(parquetTable) { sql(createTableCmd) - val catalogTable = spark.sessionState.catalog.getTableMetadata( - TableIdentifier(parquetTable)) + val catalogTable = getCatalogTable(parquetTable) assert(DDLUtils.isDatasourceTable(catalogTable)) // Add a filter to avoid creating too many partitions @@ -384,17 +1104,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto "partitioned data source table", "CREATE TABLE parquetTable (key STRING, value STRING) USING PARQUET PARTITIONED BY (key)") - test("statistics collection of a table with zero column") { - val table_no_cols = "table_no_cols" - withTable(table_no_cols) { - val rddNoCols = sparkContext.parallelize(1 to 10).map(_ => Row.empty) - val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) - dfNoCols.write.format("json").saveAsTable(table_no_cols) - sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") - checkTableStats(table_no_cols, hasSizeInBytes = true, expectedRowCounts = Some(10)) - } - } - /** Used to test refreshing cached metadata once table stats are updated. 
*/ private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean) : (CatalogStatistics, CatalogStatistics) = { @@ -442,7 +1151,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("estimates the size of a test Hive serde tables") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { - case relation: CatalogRelation => relation.stats(conf).sizeInBytes + case relation: HiveTableRelation => relation.stats.sizeInBytes } assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizes(0).equals(BigInt(5812)), @@ -462,7 +1171,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. val sizes = df.queryExecution.analyzed.collect { - case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats(conf).sizeInBytes + case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats.sizeInBytes } assert(sizes.size === 2 && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold, @@ -502,7 +1211,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto () => (), metastoreQuery, metastoreAnswer, - implicitly[ClassTag[CatalogRelation]] + implicitly[ClassTag[HiveTableRelation]] ) } @@ -516,7 +1225,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. val sizes = df.queryExecution.analyzed.collect { - case relation: CatalogRelation => relation.stats(conf).sizeInBytes + case relation: HiveTableRelation => relation.stats.sizeInBytes } assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala new file mode 100644 index 000000000000..72f8e8ff7c68 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.test.{TestHiveSingleton, TestHiveSparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + + +class TestHiveSuite extends TestHiveSingleton with SQLTestUtils { + test("load test table based on case sensitivity") { + val testHiveSparkSession = spark.asInstanceOf[TestHiveSparkSession] + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql("SELECT * FROM SRC").queryExecution.analyzed + assert(testHiveSparkSession.getLoadedTables.contains("src")) + assert(testHiveSparkSession.getLoadedTables.size == 1) + } + testHiveSparkSession.reset() + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val err = intercept[AnalysisException] { + sql("SELECT * FROM SRC").queryExecution.analyzed + } + assert(err.message.contains("Table or view not found")) + } + testHiveSparkSession.reset() + } + + test("SPARK-15887: hive-site.xml should be loaded") { + assert(hiveClient.getConf("hive.in.test", "") == "true") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala index e85ea5a59427..ae804ce7c7b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -25,9 +25,7 @@ import org.apache.hadoop.util.VersionInfo import org.apache.spark.SparkConf import org.apache.spark.util.Utils -private[client] class HiveClientBuilder { - private val sparkConf = new SparkConf() - +private[client] object HiveClientBuilder { // In order to speed up test execution during development or in Jenkins, you can specify the path // of an existing Ivy cache: private val ivyPath: Option[String] = { @@ -52,7 +50,7 @@ private[client] class HiveClientBuilder { IsolatedClientLoader.forVersion( hiveMetastoreVersion = version, hadoopVersion = VersionInfo.getVersion, - sparkConf = sparkConf, + sparkConf = new SparkConf(), hadoopConf = hadoopConf, config = buildConf(extraConf), ivyPath = ivyPath).createClient() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 4790331168bd..3eedcf7e0874 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -19,21 +19,22 @@ package org.apache.spark.sql.hive.client import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.conf.HiveConf +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} -import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -class HiveClientSuite extends SparkFunSuite { - private val clientBuilder = new HiveClientBuilder +// TODO: Refactor this to `HivePartitionFilteringSuite` +class HiveClientSuite(version: String) + extends HiveVersionSuite(version) with BeforeAndAfterAll { + import CatalystSqlParser._ private val tryDirectSqlKey = 
HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname - test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { - val testPartitionCount = 5 + private val testPartitionCount = 3 * 24 * 4 + private def init(tryDirectSql: Boolean): HiveClient = { val storageFormat = CatalogStorageFormat( locationUri = None, inputFormat = None, @@ -43,19 +44,214 @@ class HiveClientSuite extends SparkFunSuite { properties = Map.empty) val hadoopConf = new Configuration() - hadoopConf.setBoolean(tryDirectSqlKey, false) - val client = clientBuilder.buildClient(HiveUtils.hiveExecutionVersion, hadoopConf) - client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (part INT)") + hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql) + val client = buildClient(hadoopConf) + client + .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") + + val partitions = + for { + ds <- 20170101 to 20170103 + h <- 0 to 23 + chunk <- Seq("aa", "ab", "ba", "bb") + } yield CatalogTablePartition(Map( + "ds" -> ds.toString, + "h" -> h.toString, + "chunk" -> chunk + ), storageFormat) + assert(partitions.size == testPartitionCount) - val partitions = (1 to testPartitionCount).map { part => - CatalogTablePartition(Map("part" -> part.toString), storageFormat) - } client.createPartitions( "default", "test", partitions, ignoreIfExists = false) + client + } + override def beforeAll() { + client = init(true) + } + + test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { + val client = init(false) val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), - Seq(EqualTo(AttributeReference("part", IntegerType)(), Literal(3)))) + Seq(parseExpression("ds=20170101"))) assert(filteredPartitions.size == testPartitionCount) } + + test("getPartitionsByFilter: ds=20170101") { + testMetastorePartitionFiltering( + "ds=20170101", + 20170101 to 20170101, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: ds=(20170101 + 1) and h=0") { + // Should return all partitions where h=0 because getPartitionsByFilter does not support + // comparisons to non-literal values + testMetastorePartitionFiltering( + "ds=(20170101 + 1) and h=0", + 20170101 to 20170103, + 0 to 0, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: chunk='aa'") { + testMetastorePartitionFiltering( + "chunk='aa'", + 20170101 to 20170103, + 0 to 23, + "aa" :: Nil) + } + + test("getPartitionsByFilter: 20170101=ds") { + testMetastorePartitionFiltering( + "20170101=ds", + 20170101 to 20170101, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: ds=20170101 and h=10") { + testMetastorePartitionFiltering( + "ds=20170101 and h=10", + 20170101 to 20170101, + 10 to 10, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: ds=20170101 or ds=20170102") { + testMetastorePartitionFiltering( + "ds=20170101 or ds=20170102", + 20170101 to 20170102, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") { + testMetastorePartitionFiltering( + "ds in (20170102, 20170103)", + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") { + testMetastorePartitionFiltering( + "ds in (20170102, 20170103)", + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil, { + case expr @ In(v, 
list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) + }) + } + + test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") { + testMetastorePartitionFiltering( + "chunk in ('ab', 'ba')", + 20170101 to 20170103, + 0 to 23, + "ab" :: "ba" :: Nil) + } + + test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") { + testMetastorePartitionFiltering( + "chunk in ('ab', 'ba')", + 20170101 to 20170103, + 0 to 23, + "ab" :: "ba" :: Nil, { + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) + }) + } + + test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") { + val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) + val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb")) + testMetastorePartitionFiltering( + "(ds=20170101 and h>=8) or (ds=20170102 and h<8)", + day1 :: day2 :: Nil) + } + + test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") { + val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) + // Day 2 should include all hours because we can't build a filter for h<(7+1) + val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb")) + testMetastorePartitionFiltering( + "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))", + day1 :: day2 :: Nil) + } + + test("getPartitionsByFilter: " + + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { + val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) + val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) + testMetastorePartitionFiltering( + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))", + day1 :: day2 :: Nil) + } + + private def testMetastorePartitionFiltering( + filterString: String, + expectedDs: Seq[Int], + expectedH: Seq[Int], + expectedChunks: Seq[String]): Unit = { + testMetastorePartitionFiltering( + filterString, + (expectedDs, expectedH, expectedChunks) :: Nil, + identity) + } + + private def testMetastorePartitionFiltering( + filterString: String, + expectedDs: Seq[Int], + expectedH: Seq[Int], + expectedChunks: Seq[String], + transform: Expression => Expression): Unit = { + testMetastorePartitionFiltering( + filterString, + (expectedDs, expectedH, expectedChunks) :: Nil, + transform) + } + + private def testMetastorePartitionFiltering( + filterString: String, + expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = { + testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity) + } + + private def testMetastorePartitionFiltering( + filterString: String, + expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])], + transform: Expression => Expression): Unit = { + val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), + Seq( + transform(parseExpression(filterString)) + )) + + val expectedPartitionCount = expectedPartitionCubes.map { + case (expectedDs, expectedH, expectedChunks) => + expectedDs.size * expectedH.size * expectedChunks.size + }.sum + + val expectedPartitions = expectedPartitionCubes.map { + case (expectedDs, expectedH, expectedChunks) => + for { + ds <- expectedDs + h <- expectedH + chunk <- expectedChunks + } yield Set( + "ds" -> ds.toString, + "h" -> h.toString, + "chunk" -> chunk + ) + }.reduce(_ ++ _) + + val actualFilteredPartitionCount = filteredPartitions.size + + assert(actualFilteredPartitionCount == expectedPartitionCount, + s"Expected 
$expectedPartitionCount partitions but got $actualFilteredPartitionCount") + assert(filteredPartitions.map(_.spec.toSet).toSet == expectedPartitions.toSet) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala new file mode 100644 index 000000000000..de1be2115b2d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import scala.collection.immutable.IndexedSeq + +import org.scalatest.Suite + +class HiveClientSuites extends Suite with HiveClientVersions { + override def nestedSuites: IndexedSeq[Suite] = { + // Hive 0.12 does not provide the partition filtering API we call + versions.filterNot(_ == "0.12").map(new HiveClientSuite(_)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala new file mode 100644 index 000000000000..2e7dfde8b2fa --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.client + +import scala.collection.immutable.IndexedSeq + +import org.apache.spark.SparkFunSuite + +private[client] trait HiveClientVersions { + protected val versions = IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala new file mode 100644 index 000000000000..951ebfad4590 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import org.apache.hadoop.conf.Configuration +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.hive.HiveUtils + +private[client] abstract class HiveVersionSuite(version: String) extends SparkFunSuite { + protected var client: HiveClient = null + + protected def buildClient(hadoopConf: Configuration): HiveClient = { + // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and + // hive.metastore.schema.verification from false to true since 2.0 + // For details, see the JIRA HIVE-6113 and HIVE-12463 + if (version == "2.0" || version == "2.1") { + hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + hadoopConf.set("hive.metastore.schema.verification", "false") + } + HiveClientBuilder + .buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) + } + + override def suiteName: String = s"${super.suiteName}($version)" + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(s"$version: $testName", testTags: _*)(testFun) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 7aff49c0fc3b..edb9a9ffbaaf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,7 +21,6 @@ import java.io.{ByteArrayOutputStream, File, PrintStream} import java.net.URI import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat @@ -47,11 +46,11 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} * sure that reflective calls are not throwing NoSuchMethod error, but the actually 
functionality * is not fully tested. */ +// TODO: Refactor this to `HiveClientSuite` and make it a subclass of `HiveVersionSuite` @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { - private val clientBuilder = new HiveClientBuilder - import clientBuilder.buildClient + import HiveClientBuilder.buildClient /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` @@ -128,7 +127,7 @@ class VersionsSuite extends SparkFunSuite with Logging { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } - client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) + client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) if (versionSpark != null) versionSpark.reset() versionSpark = TestHiveVersion(client) assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client @@ -233,12 +232,49 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(client.getTable("default", "src").properties.contains("changed")) } - test(s"$version: alterTable(tableName: String, table: CatalogTable)") { + test(s"$version: alterTable(dbName: String, tableName: String, table: CatalogTable)") { val newTable = client.getTable("default", "src").copy(properties = Map("changedAgain" -> "")) - client.alterTable("src", newTable) + client.alterTable("default", "src", newTable) assert(client.getTable("default", "src").properties.contains("changedAgain")) } + test(s"$version: alterTable - rename") { + val newTable = client.getTable("default", "src") + .copy(identifier = TableIdentifier("tgt", database = Some("default"))) + assert(!client.tableExists("default", "tgt")) + + client.alterTable("default", "src", newTable) + + assert(client.tableExists("default", "tgt")) + assert(!client.tableExists("default", "src")) + } + + test(s"$version: alterTable - change database") { + val tempDB = CatalogDatabase( + "temporary", description = "test create", tempDatabasePath, Map()) + client.createDatabase(tempDB, ignoreIfExists = true) + + val newTable = client.getTable("default", "tgt") + .copy(identifier = TableIdentifier("tgt", database = Some("temporary"))) + assert(!client.tableExists("temporary", "tgt")) + + client.alterTable("default", "tgt", newTable) + + assert(client.tableExists("temporary", "tgt")) + assert(!client.tableExists("default", "tgt")) + } + + test(s"$version: alterTable - change database and table names") { + val newTable = client.getTable("temporary", "tgt") + .copy(identifier = TableIdentifier("src", database = Some("default"))) + assert(!client.tableExists("default", "src")) + + client.alterTable("temporary", "tgt", newTable) + + assert(client.tableExists("default", "src")) + assert(!client.tableExists("temporary", "tgt")) + } + test(s"$version: listTables(database)") { assert(client.listTables("default") === Seq("src", "temporary")) } @@ -576,7 +612,7 @@ class VersionsSuite extends SparkFunSuite with Logging { versionSpark.sql("CREATE TABLE tbl AS SELECT 1 AS a") assert(versionSpark.table("tbl").collect().toSeq == Seq(Row(1))) val tableMeta = versionSpark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")) - val totalSize = tableMeta.properties.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val totalSize = tableMeta.stats.map(_.sizeInBytes) // Except 0.12, all the following versions will fill the Hive-generated statistics if (version == "0.12") { assert(totalSize.isEmpty) @@ -697,6 
+733,114 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(versionSpark.table("t1").collect() === Array(Row(2))) } } + + test(s"$version: Decimal support of Avro Hive serde") { + val tableName = "tab1" + // TODO: add the other logical types. For details, see the link: + // https://avro.apache.org/docs/1.8.1/spec.html#Logical+Types + val avroSchema = + """{ + | "name": "test_record", + | "type": "record", + | "fields": [ { + | "name": "f0", + | "type": [ + | "null", + | { + | "precision": 38, + | "scale": 2, + | "type": "bytes", + | "logicalType": "decimal" + | } + | ] + | } ] + |} + """.stripMargin + + Seq(true, false).foreach { isPartitioned => + withTable(tableName) { + val partitionClause = if (isPartitioned) "PARTITIONED BY (ds STRING)" else "" + // Creates the (non-)partitioned Avro table + versionSpark.sql( + s""" + |CREATE TABLE $tableName + |$partitionClause + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') + """.stripMargin + ) + + val errorMsg = "data type mismatch: cannot cast DecimalType(2,1) to BinaryType" + + if (isPartitioned) { + val insertStmt = s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1.3" + if (version == "0.12" || version == "0.13") { + val e = intercept[AnalysisException](versionSpark.sql(insertStmt)).getMessage + assert(e.contains(errorMsg)) + } else { + versionSpark.sql(insertStmt) + assert(versionSpark.table(tableName).collect() === + versionSpark.sql("SELECT 1.30, 'a'").collect()) + } + } else { + val insertStmt = s"INSERT OVERWRITE TABLE $tableName SELECT 1.3" + if (version == "0.12" || version == "0.13") { + val e = intercept[AnalysisException](versionSpark.sql(insertStmt)).getMessage + assert(e.contains(errorMsg)) + } else { + versionSpark.sql(insertStmt) + assert(versionSpark.table(tableName).collect() === + versionSpark.sql("SELECT 1.30").collect()) + } + } + } + } + } + + test(s"$version: read avro file containing decimal") { + val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") + val location = new File(url.getFile) + + val tableName = "tab1" + val avroSchema = + """{ + | "name": "test_record", + | "type": "record", + | "fields": [ { + | "name": "f0", + | "type": [ + | "null", + | { + | "precision": 38, + | "scale": 2, + | "type": "bytes", + | "logicalType": "decimal" + | } + | ] + | } ] + |} + """.stripMargin + withTable(tableName) { + versionSpark.sql( + s""" + |CREATE TABLE $tableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |LOCATION '$location' + |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') + """.stripMargin + ) + assert(versionSpark.table(tableName).collect() === + versionSpark.sql("SELECT 1.30").collect()) + } + } + // TODO: add more tests. 
} } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 84f915977bd8..f245a79f805a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ import scala.util.Random +import test.org.apache.spark.sql.MyDoubleAvg +import test.org.apache.spark.sql.MyDoubleSum + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ + class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { def inputSchema: StructType = schema diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index abe5d835719b..cee82cda4628 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} @@ -192,12 +193,7 @@ abstract class HiveComparisonTest "last_modified_by", "last_modified_time", "Owner:", - "COLUMN_STATS_ACCURATE", // The following are hive specific schema parameters which we do not need to match exactly. 
- "numFiles", - "numRows", - "rawDataSize", - "totalSize", "totalNumberFiles", "maxFileSize", "minFileSize" @@ -346,7 +342,10 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath)) - try { (query, prepareAnswer(query, query.hiveResultString())) } catch { + def getResult(): Seq[String] = { + SQLExecution.withNewExecutionId(query.sparkSession, query)(query.hiveResultString()) + } + try { (query, prepareAnswer(query, getResult())) } catch { case e: Throwable => val errorMessage = s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 16a99321bad3..668da5fb4732 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -20,22 +20,25 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.net.URI +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils // TODO(gatorsmile): combine HiveCatalogedDDLSuite and HiveDDLSuite class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeAndAfterEach { @@ -50,15 +53,28 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA protected override def generateTable( catalog: SessionCatalog, - name: TableIdentifier): CatalogTable = { + name: TableIdentifier, + isDataSource: Boolean): CatalogTable = { val storage = - CatalogStorageFormat( - locationUri = Some(catalog.defaultTablePath(name)), - inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), - serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), - compressed = false, - properties = Map("serialization.format" -> "1")) + if (isDataSource) { + val serde = HiveSerDe.sourceToSerDe("parquet") + assert(serde.isDefined, "The default format is not Hive compatible") + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = serde.get.inputFormat, + outputFormat = serde.get.outputFormat, + serde = serde.get.serde, + compressed = false, + properties = Map("serialization.format" -> "1")) + } else { + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), + 
outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), + compressed = false, + properties = Map("serialization.format" -> "1")) + } val metadata = new MetadataBuilder() .putString("key", "value") .build() @@ -71,9 +87,10 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA .add("col2", "string") .add("a", "int") .add("b", "int"), - provider = Some("hive"), + provider = if (isDataSource) Some("parquet") else Some("hive"), partitionColumnNames = Seq("a", "b"), createTime = 0L, + createVersion = org.apache.spark.SPARK_VERSION, tracksPartitionsInCatalog = true) } @@ -107,6 +124,45 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA ) } + test("alter table: set location") { + testSetLocation(isDatasourceTable = false) + } + + test("alter table: set properties") { + testSetProperties(isDatasourceTable = false) + } + + test("alter table: unset properties") { + testUnsetProperties(isDatasourceTable = false) + } + + test("alter table: set serde") { + testSetSerde(isDatasourceTable = false) + } + + test("alter table: set serde partition") { + testSetSerdePartition(isDatasourceTable = false) + } + + test("alter table: change column") { + testChangeColumn(isDatasourceTable = false) + } + + test("alter table: rename partition") { + testRenamePartitions(isDatasourceTable = false) + } + + test("alter table: drop partition") { + testDropPartitions(isDatasourceTable = false) + } + + test("alter table: add partition") { + testAddPartitions(isDatasourceTable = false) + } + + test("drop table") { + testDropTable(isDatasourceTable = false) + } } class HiveDDLSuite @@ -130,7 +186,7 @@ class HiveDDLSuite if (dbPath.isEmpty) { hiveContext.sessionState.catalog.defaultTablePath(tableIdentifier) } else { - new Path(new Path(dbPath.get), tableIdentifier.table) + new Path(new Path(dbPath.get), tableIdentifier.table).toUri } val filesystemPath = new Path(expectedTablePath.toString) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) @@ -292,7 +348,7 @@ class HiveDDLSuite val e = intercept[AnalysisException] { sql("CREATE TABLE tbl(a int) PARTITIONED BY (a string)") } - assert(e.message == "Found duplicate column(s) in table definition of `default`.`tbl`: a") + assert(e.message == "Found duplicate column(s) in the table definition of `default`.`tbl`: `a`") } test("add/drop partition with location - managed table") { @@ -620,7 +676,7 @@ class HiveDDLSuite |""".stripMargin) val newPart = catalog.getPartition(TableIdentifier("boxes"), Map("width" -> "4")) assert(newPart.storage.serde == Some(expectedSerde)) - assume(newPart.storage.properties.filterKeys(expectedSerdeProps.contains) == + assert(newPart.storage.properties.filterKeys(expectedSerdeProps.contains) == expectedSerdeProps) } @@ -724,6 +780,26 @@ class HiveDDLSuite } } + test("desc table for Hive table - bucketed + sorted table") { + withTable("tbl") { + sql(s""" + CREATE TABLE tbl (id int, name string) + PARTITIONED BY (ds string) + CLUSTERED BY(id) + SORTED BY(id, name) INTO 1024 BUCKETS + """) + + val x = sql("DESC FORMATTED tbl").collect() + assert(x.containsSlice( + Seq( + Row("Num Buckets", "1024", ""), + Row("Bucket Columns", "[`id`]", ""), + Row("Sort Columns", "[`id`, `name`]", "") + ) + )) + } + } + test("desc table for data source table using Hive Metastore") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") val tabName = "tab1" @@ -732,7 +808,7 @@ class 
HiveDDLSuite checkAnswer( sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("# col_name", "data_type", "comment") :: Row("a", "int", "test") :: Nil + Row("a", "int", "test") :: Nil ) } } @@ -1126,11 +1202,6 @@ class HiveDDLSuite "last_modified_by", "last_modified_time", "Owner:", - "COLUMN_STATS_ACCURATE", - "numFiles", - "numRows", - "rawDataSize", - "totalSize", "totalNumberFiles", "maxFileSize", "minFileSize" @@ -1583,7 +1654,7 @@ class HiveDDLSuite test("create hive table with a non-existing location") { withTable("t", "t1") { withTempPath { dir => - spark.sql(s"CREATE TABLE t(a int, b int) USING hive LOCATION '$dir'") + spark.sql(s"CREATE TABLE t(a int, b int) USING hive LOCATION '${dir.toURI}'") val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -1600,7 +1671,7 @@ class HiveDDLSuite |CREATE TABLE t1(a int, b int) |USING hive |PARTITIONED BY(a) - |LOCATION '$dir' + |LOCATION '${dir.toURI}' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) @@ -1628,7 +1699,7 @@ class HiveDDLSuite s""" |CREATE TABLE t |USING hive - |LOCATION '$dir' + |LOCATION '${dir.toURI}' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -1644,7 +1715,7 @@ class HiveDDLSuite |CREATE TABLE t1 |USING hive |PARTITIONED BY(a, b) - |LOCATION '$dir' + |LOCATION '${dir.toURI}' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) @@ -1670,21 +1741,21 @@ class HiveDDLSuite |CREATE TABLE t(a string, `$specialChars` string) |USING $datasource |PARTITIONED BY(`$specialChars`) - |LOCATION '$dir' + |LOCATION '${dir.toURI}' """.stripMargin) assert(dir.listFiles().isEmpty) spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" val partFile = new File(dir, partEscaped) - assert(partFile.listFiles().length >= 1) + assert(partFile.listFiles().nonEmpty) checkAnswer(spark.table("t"), Row("1", "2") :: Nil) withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`) SELECT 3, 4") val partEscaped1 = s"${ExternalCatalogUtils.escapePathName(specialChars)}=4" val partFile1 = new File(dir, partEscaped1) - assert(partFile1.listFiles().length >= 1) + assert(partFile1.listFiles().nonEmpty) checkAnswer(spark.table("t"), Row("1", "2") :: Row("3", "4") :: Nil) } } @@ -1695,15 +1766,22 @@ class HiveDDLSuite Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"hive table: location uri contains $specialChars") { + // On Windows, it looks colon in the file name is illegal by default. See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + withTable("t") { withTempDir { dir => val loc = new File(dir, specialChars) loc.mkdir() + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. 
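+ // For example, a Windows location such as C:\tmp\a b has to appear as C:\\tmp\\a b in the LOCATION clause below.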
+ val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t(a string) |USING hive - |LOCATION '$loc' + |LOCATION '$escapedLoc' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -1726,12 +1804,13 @@ class HiveDDLSuite withTempDir { dir => val loc = new File(dir, specialChars) loc.mkdir() + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t1(a string, b string) |USING hive |PARTITIONED BY(b) - |LOCATION '$loc' + |LOCATION '$escapedLoc' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) @@ -1742,16 +1821,20 @@ class HiveDDLSuite if (specialChars != "a:b") { spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") val partFile = new File(loc, "b=2") - assert(partFile.listFiles().length >= 1) + assert(partFile.listFiles().nonEmpty) checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") assert(!partFile1.exists()) - val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") - assert(partFile2.listFiles().length >= 1) - checkAnswer(spark.table("t1"), - Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + + if (!Utils.isWindows) { + // Actual path becomes "b=2017-03-03%2012%3A13%253A14" on Windows. + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().nonEmpty) + checkAnswer(spark.table("t1"), + Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } } else { val e = intercept[AnalysisException] { spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") @@ -1875,4 +1958,55 @@ class HiveDDLSuite } } } + + test("SPARK-21216: join with a streaming DataFrame") { + import org.apache.spark.sql.execution.streaming.MemoryStream + import testImplicits._ + + implicit val _sqlContext = spark.sqlContext + + Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word").createOrReplaceTempView("t1") + // Make a table and ensure it will be broadcast. 
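+ // (The handful of rows inserted below keeps it well under spark.sql.autoBroadcastJoinThreshold, so the stream-static join can use a broadcast hash join.)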
+ sql("""CREATE TABLE smallTable(word string, number int) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |STORED AS TEXTFILE + """.stripMargin) + + sql( + """INSERT INTO smallTable + |SELECT word, number from t1 + """.stripMargin) + + val inputData = MemoryStream[Int] + val joined = inputData.toDS().toDF() + .join(spark.table("smallTable"), $"value" === $"number") + + val sq = joined.writeStream + .format("memory") + .queryName("t2") + .start() + try { + inputData.addData(1, 2) + + sq.processAllAvailable() + + checkAnswer( + spark.table("t2"), + Seq(Row(1, "one", 1), Row(2, "two", 2)) + ) + } finally { + sq.stop() + } + } + + test("table name with schema") { + // regression test for SPARK-11778 + withDatabase("usrdb") { + spark.sql("create schema usrdb") + withTable("usrdb.test") { + spark.sql("create table usrdb.test(c int)") + spark.read.table("usrdb.test") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index aa1ca2909074..3066a4f305f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -29,6 +29,12 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto import testImplicits._ test("show cost in explain command") { + // For readability, we only show optimized plan and physical plan in explain cost command + checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), + "Optimized Logical Plan", "Physical Plan") + checkKeywordsNotExist(sql("EXPLAIN COST SELECT * FROM src "), + "Parsed Logical Plan", "Analyzed Logical Plan") + // Only has sizeInBytes before ANALYZE command checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), "sizeInBytes") checkKeywordsNotExist(sql("EXPLAIN COST SELECT * FROM src "), "rowCount") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index 89e6edb6b157..78cdc67800c1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton class HivePlanTest extends QueryTest with TestHiveSingleton { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index cf3376036072..2ea51791d0f7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -370,21 +370,23 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd """.stripMargin) test("SPARK-7270: consider dynamic partition when comparing table output") { - sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)") - sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)") + withTable("test_partition", "ptest") { + 
sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)") + sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)") - val analyzedPlan = sql( - """ + val analyzedPlan = sql( + """ |INSERT OVERWRITE table test_partition PARTITION (b=1, c) |SELECT 'a', 'c' from ptest """.stripMargin).queryExecution.analyzed - assertResult(false, "Incorrect cast detected\n" + analyzedPlan) { + assertResult(false, "Incorrect cast detected\n" + analyzedPlan) { var hasCast = false - analyzedPlan.collect { - case p: Project => p.transformExpressionsUp { case c: Cast => hasCast = true; c } + analyzedPlan.collect { + case p: Project => p.transformExpressionsUp { case c: Cast => hasCast = true; c } + } + hasCast } - hasCast } } @@ -435,13 +437,13 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd test("transform with SerDe2") { assume(TestUtils.testCommandAvailable("/bin/bash")) + withTable("small_src") { + sql("CREATE TABLE small_src(key INT, value STRING)") + sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10") - sql("CREATE TABLE small_src(key INT, value STRING)") - sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10") - - val expected = sql("SELECT key FROM small_src").collect().head - val res = sql( - """ + val expected = sql("SELECT key FROM small_src").collect().head + val res = sql( + """ |SELECT TRANSFORM (key) ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.avro.AvroSerDe' |WITH SERDEPROPERTIES ('avro.schema.literal'='{"namespace": @@ -453,7 +455,8 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd |FROM small_src """.stripMargin.replaceAll(System.lineSeparator(), " ")).collect().head - assert(expected(0) === res(0)) + assert(expected(0) === res(0)) + } } createQueryTest("transform with SerDe3", @@ -780,22 +783,26 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd test("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" - val q0 = sql(s"CREATE TABLE $tableName(key INT, value STRING)") + withTable(tableName) { + val q0 = sql(s"CREATE TABLE $tableName(key INT, value STRING)") - // If the table was not created, the following assertion would fail - assert(Try(table(tableName)).isSuccess) + // If the table was not created, the following assertion would fail + assert(Try(table(tableName)).isSuccess) - // If the CREATE TABLE command got executed again, the following assertion would fail - assert(Try(q0.count()).isSuccess) + // If the CREATE TABLE command got executed again, the following assertion would fail + assert(Try(q0.count()).isSuccess) + } } test("SPARK-2263: Insert Map values") { - sql("CREATE TABLE m(value MAP)") - sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).foreach { - case (Row(map: Map[_, _]), Row(key: Int, value: String)) => - assert(map.size === 1) - assert(map.head === (key, value)) + withTable("m") { + sql("CREATE TABLE m(value MAP)") + sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") + sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).foreach { + case (Row(map: Map[_, _]), Row(key: Int, value: String)) => + assert(map.size === 1) + assert(map.head === ((key, value))) + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index 5afb37b382e6..97e4c2b6b2db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.types.StructType * A test suite for Hive view related functionality. */ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { - protected override val spark: SparkSession = TestHive.sparkSession - import testImplicits._ test("create a permanent/temp view using a hive, built-in, and permanent user function") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 7803ac39e508..1c9f00141ae1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -17,15 +17,23 @@ package org.apache.spark.sql.hive.execution +import java.net.URI + import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} +import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.types.StructType /** * A set of tests that validates support for Hive SerDe. */ -class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { +class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfterAll { override def beforeAll(): Unit = { import TestHive._ import org.apache.hadoop.hive.serde2.RegexSerDe @@ -60,4 +68,127 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { val serdeinsRes = InputOutputMetricsHelper.run(sql("select * from serdeins").toDF()) assert(serdeinsRes === (serdeinsCnt, 0L, serdeinsCnt) :: Nil) } + + private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { + TestHive.sessionState.sqlParser.parsePlan(sql).collect { + case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) + }.head + } + + private def analyzeCreateTable(sql: String): CatalogTable = { + TestHive.sessionState.analyzer.execute(TestHive.sessionState.sqlParser.parsePlan(sql)).collect { + case CreateTableCommand(tableDesc, _) => tableDesc + }.head + } + + test("Test the default fileformat for Hive-serde tables") { + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + withSQLConf("hive.default.fileformat" -> "parquet") { + val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") + assert(exists) + val input = desc.storage.inputFormat + val output = desc.storage.outputFormat + val serde = desc.storage.serde + assert(input == 
Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } + } + + test("create hive serde table with new syntax - basic") { + val sql = + """ + |CREATE TABLE t + |(id int, name string COMMENT 'blabla') + |USING hive + |OPTIONS (fileFormat 'parquet', my_prop 1) + |LOCATION '/tmp/file' + |COMMENT 'BLABLA' + """.stripMargin + + val table = analyzeCreateTable(sql) + assert(table.schema == new StructType() + .add("id", "int") + .add("name", "string", nullable = true, comment = "blabla")) + assert(table.provider == Some(DDLUtils.HIVE_PROVIDER)) + assert(table.storage.locationUri == Some(new URI("/tmp/file"))) + assert(table.storage.properties == Map("my_prop" -> "1")) + assert(table.comment == Some("BLABLA")) + + assert(table.storage.inputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(table.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(table.storage.serde == + Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } + + test("create hive serde table with new syntax - with partition and bucketing") { + val v1 = "CREATE TABLE t (c1 int, c2 int) USING hive PARTITIONED BY (c2)" + val table = analyzeCreateTable(v1) + assert(table.schema == new StructType().add("c1", "int").add("c2", "int")) + assert(table.partitionColumnNames == Seq("c2")) + // check the default formats + assert(table.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(table.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(table.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + + val v2 = "CREATE TABLE t (c1 int, c2 int) USING hive CLUSTERED BY (c2) INTO 4 BUCKETS" + val e2 = intercept[AnalysisException](analyzeCreateTable(v2)) + assert(e2.message.contains("Creating bucketed Hive serde table is not supported yet")) + + val v3 = + """ + |CREATE TABLE t (c1 int, c2 int) USING hive + |PARTITIONED BY (c2) + |CLUSTERED BY (c2) INTO 4 BUCKETS""".stripMargin + val e3 = intercept[AnalysisException](analyzeCreateTable(v3)) + assert(e3.message.contains("Creating bucketed Hive serde table is not supported yet")) + } + + test("create hive serde table with new syntax - Hive options error checking") { + val v1 = "CREATE TABLE t (c1 int) USING hive OPTIONS (inputFormat 'abc')" + val e1 = intercept[IllegalArgumentException](analyzeCreateTable(v1)) + assert(e1.getMessage.contains("Cannot specify only inputFormat or outputFormat")) + + val v2 = "CREATE TABLE t (c1 int) USING hive OPTIONS " + + "(fileFormat 'x', inputFormat 'a', outputFormat 'b')" + val e2 = intercept[IllegalArgumentException](analyzeCreateTable(v2)) + assert(e2.getMessage.contains( + "Cannot specify fileFormat and inputFormat/outputFormat together")) + + val v3 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', serde 'a')" + val e3 = intercept[IllegalArgumentException](analyzeCreateTable(v3)) + assert(e3.getMessage.contains("fileFormat 'parquet' already specifies a serde")) + + val v4 = "CREATE TABLE t (c1 int) USING hive OPTIONS (serde 'a', fieldDelim ' ')" + val e4 = intercept[IllegalArgumentException](analyzeCreateTable(v4)) + assert(e4.getMessage.contains("Cannot specify delimiters with a custom serde")) + + val v5 = "CREATE TABLE t (c1 int) 
USING hive OPTIONS (fieldDelim ' ')" + val e5 = intercept[IllegalArgumentException](analyzeCreateTable(v5)) + assert(e5.getMessage.contains("Cannot specify delimiters without fileFormat")) + + val v6 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', fieldDelim ' ')" + val e6 = intercept[IllegalArgumentException](analyzeCreateTable(v6)) + assert(e6.getMessage.contains( + "Cannot specify delimiters as they are only compatible with fileFormat 'textfile'")) + + // The value of 'fileFormat' option is case-insensitive. + val v7 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'TEXTFILE', lineDelim ',')" + val e7 = intercept[IllegalArgumentException](analyzeCreateTable(v7)) + assert(e7.getMessage.contains("Hive data source only support newline '\\n' as line delimiter")) + + val v8 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'wrong')" + val e8 = intercept[IllegalArgumentException](analyzeCreateTable(v8)) + assert(e8.getMessage.contains("invalid fileFormat: 'wrong'")) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 90e037e29279..3f9bb8de42e0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -81,14 +81,16 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH } test("Spark-4959 Attributes are case sensitive when using a select query from a projection") { - sql("create table spark_4959 (col1 string)") - sql("""insert into table spark_4959 select "hi" from src limit 1""") - table("spark_4959").select( - 'col1.as("CaseSensitiveColName"), - 'col1.as("CaseSensitiveColName2")).createOrReplaceTempView("spark_4959_2") - - assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) - assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) + withTable("spark_4959") { + sql("create table spark_4959 (col1 string)") + sql("""insert into table spark_4959 select "hi" from src limit 1""") + table("spark_4959").select( + 'col1.as("CaseSensitiveColName"), + 'col1.as("CaseSensitiveColName2")).createOrReplaceTempView("spark_4959_2") + + assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) + assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) + } } private def checkNumScannedPartitions(stmt: String, expectedNumParts: Int): Unit = { @@ -164,16 +166,30 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH |PARTITION (p1='a',p2='c',p3='c',p4='d',p5='e') |SELECT v.id """.stripMargin) - val plan = sql( - s""" - |SELECT * FROM $table - """.stripMargin).queryExecution.sparkPlan - val scan = plan.collectFirst { - case p: HiveTableScanExec => p - }.get + val scan = getHiveTableScanExec(s"SELECT * FROM $table") val numDataCols = scan.relation.dataCols.length scan.rawPartitions.foreach(p => assert(p.getCols.size == numDataCols)) } } } + + test("HiveTableScanExec canonicalization for different orders of partition filters") { + val table = "hive_tbl_part" + withTable(table) { + sql( + s""" + |CREATE TABLE $table (id int) + |PARTITIONED BY (a int, b int) + """.stripMargin) + val scan1 = getHiveTableScanExec(s"SELECT * FROM $table WHERE a = 1 AND b = 2") + val scan2 = getHiveTableScanExec(s"SELECT * FROM $table WHERE b = 2 AND a = 1") + 
assert(scan1.sameResult(scan2)) + } + } + + private def getHiveTableScanExec(query: String): HiveTableScanExec = { + sql(query).queryExecution.sparkPlan.collectFirst { + case p: HiveTableScanExec => p + }.get + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 479ca1e8def5..8986fb58c646 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo +import test.org.apache.spark.sql.MyDoubleAvg import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec @@ -86,6 +87,18 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { )) } + test("call JAVA UDAF") { + withTempView("temp") { + withUserDefinedFunction("myDoubleAvg" -> false) { + spark.range(1, 10).toDF("value").createOrReplaceTempView("temp") + sql(s"CREATE FUNCTION myDoubleAvg AS '${classOf[MyDoubleAvg].getName}'") + checkAnswer( + spark.sql("SELECT default.myDoubleAvg(value) as my_avg from temp"), + Row(105.0)) + } + } + } + test("non-deterministic children expressions of UDAF") { withTempView("view1") { spark.range(1).selectExpr("id as x", "id as y").createTempView("view1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 4446af2e75e0..6198d4963df3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions.max import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils @@ -73,26 +74,28 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("hive struct udf") { - sql( - """ - |CREATE TABLE hiveUDFTestTable ( - | pair STRUCT - |) - |PARTITIONED BY (partition STRING) - |ROW FORMAT SERDE '%s' - |STORED AS SEQUENCEFILE - """. - stripMargin.format(classOf[PairSerDe].getName)) - - val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile - sql(s""" - ALTER TABLE hiveUDFTestTable - ADD IF NOT EXISTS PARTITION(partition='testUDF') - LOCATION '$location'""") - - sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") - sql("SELECT testUDF(pair) FROM hiveUDFTestTable") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") + withTable("hiveUDFTestTable") { + sql( + """ + |CREATE TABLE hiveUDFTestTable ( + | pair STRUCT + |) + |PARTITIONED BY (partition STRING) + |ROW FORMAT SERDE '%s' + |STORED AS SEQUENCEFILE + """. 
+ stripMargin.format(classOf[PairSerDe].getName)) + + val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile + sql(s""" + ALTER TABLE hiveUDFTestTable + ADD IF NOT EXISTS PARTITION(partition='testUDF') + LOCATION '$location'""") + + sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") + } } test("Max/Min on named_struct") { @@ -193,7 +196,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { checkAnswer(sql("SELECT percentile_approx(100.0D, array(0.9D, 0.9D)) FROM src LIMIT 1"), sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) - } + } test("UDFIntegerToString") { val testData = spark.sparkContext.parallelize( @@ -403,59 +406,34 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("Hive UDFs with insufficient number of input arguments should trigger an analysis error") { - Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF") + withTempView("testUDF") { + Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF") + + def testErrorMsgForFunc(funcName: String, className: String): Unit = { + withUserDefinedFunction(funcName -> true) { + sql(s"CREATE TEMPORARY FUNCTION $funcName AS '$className'") + val message = intercept[AnalysisException] { + sql(s"SELECT $funcName() FROM testUDF") + }.getMessage + assert(message.contains(s"No handler for UDF/UDAF/UDTF '$className'")) + } + } - { // HiveSimpleUDF - sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDFTwoListList() FROM testUDF") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - } + testErrorMsgForFunc("testUDFTwoListList", classOf[UDFTwoListList].getName) - { // HiveGenericUDF - sql(s"CREATE TEMPORARY FUNCTION testUDFAnd AS '${classOf[GenericUDFOPAnd].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDFAnd() FROM testUDF") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFAnd") - } + testErrorMsgForFunc("testUDFAnd", classOf[GenericUDFOPAnd].getName) - { // Hive UDAF - sql(s"CREATE TEMPORARY FUNCTION testUDAFPercentile AS '${classOf[UDAFPercentile].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDAFPercentile(a) FROM testUDF GROUP BY b") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFPercentile") - } + testErrorMsgForFunc("testUDAFPercentile", classOf[UDAFPercentile].getName) - { // AbstractGenericUDAFResolver - sql(s"CREATE TEMPORARY FUNCTION testUDAFAverage AS '${classOf[GenericUDAFAverage].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDAFAverage() FROM testUDF GROUP BY b") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFAverage") - } + testErrorMsgForFunc("testUDAFAverage", classOf[GenericUDAFAverage].getName) - { - // Hive UDTF - sql(s"CREATE TEMPORARY FUNCTION testUDTFExplode AS '${classOf[GenericUDTFExplode].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDTFExplode() FROM testUDF") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF 
EXISTS testUDTFExplode") + // AbstractGenericUDAFResolver + testErrorMsgForFunc("testUDTFExplode", classOf[GenericUDTFExplode].getName) } - - spark.catalog.dropTempView("testUDF") } test("Hive UDF in group by") { @@ -590,6 +568,76 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } } + + test("Temp function has dots in the names") { + withUserDefinedFunction("test_avg" -> false, "`default.test_avg`" -> true) { + sql(s"CREATE FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer(sql("SELECT test_avg(1)"), Row(1.0)) + // temp function containing dots in the name + spark.udf.register("default.test_avg", () => { Math.random() + 2}) + assert(sql("SELECT `default.test_avg`()").head().getDouble(0) >= 2.0) + checkAnswer(sql("SELECT test_avg(1)"), Row(1.0)) + } + } + + test("Call the function registered in the not-current database") { + Seq("true", "false").foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { + withDatabase("dAtABaSe1") { + sql("CREATE DATABASE dAtABaSe1") + withUserDefinedFunction("dAtABaSe1.test_avg" -> false) { + sql(s"CREATE FUNCTION dAtABaSe1.test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer(sql("SELECT dAtABaSe1.test_avg(1)"), Row(1.0)) + } + val message = intercept[AnalysisException] { + sql("SELECT dAtABaSe1.unknownFunc(1)") + }.getMessage + assert(message.contains("Undefined function: 'unknownFunc'") && + message.contains("nor a permanent function registered in the database 'dAtABaSe1'")) + } + } + } + } + + test("UDTF") { + withUserDefinedFunction("udtf_count2" -> true) { + sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + } + + test("permanent UDTF") { + withUserDefinedFunction("udtf_count_temp" -> false) { + sql( + s""" + |CREATE FUNCTION udtf_count_temp + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").toURI}' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala new file mode 100644 index 000000000000..5c248b9acd04 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.language.existentials + +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.sql.types._ +import org.apache.spark.tags.ExtendedHiveTest +import org.apache.spark.util.Utils + +/** + * A separate set of DDL tests that uses Hive 2.1 libraries, which behave a little differently + * from the built-in ones. + */ +@ExtendedHiveTest +class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with BeforeAndAfterEach + with BeforeAndAfterAll { + + // Create a custom HiveExternalCatalog instance with the desired configuration. We cannot + // use SparkSession here since there's already an active on managed by the TestHive object. + private var catalog = { + val warehouse = Utils.createTempDir() + val metastore = Utils.createTempDir() + metastore.delete() + val sparkConf = new SparkConf() + .set(SparkLauncher.SPARK_MASTER, "local") + .set(WAREHOUSE_PATH.key, warehouse.toURI().toString()) + .set(CATALOG_IMPLEMENTATION.key, "hive") + .set(HiveUtils.HIVE_METASTORE_VERSION.key, "2.1") + .set(HiveUtils.HIVE_METASTORE_JARS.key, "maven") + + val hadoopConf = new Configuration() + hadoopConf.set("hive.metastore.warehouse.dir", warehouse.toURI().toString()) + hadoopConf.set("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=${metastore.getAbsolutePath()};create=true") + // These options are needed since the defaults in Hive 2.1 cause exceptions with an + // empty metastore db. 
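+ // ("datanucleus.schema.autoCreateAll" lets the embedded Derby metastore create its schema tables on first use, and disabling "hive.metastore.schema.verification" skips the version check against that fresh database.)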
+ hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + hadoopConf.set("hive.metastore.schema.verification", "false") + + new HiveExternalCatalog(sparkConf, hadoopConf) + } + + override def afterEach: Unit = { + catalog.listTables("default").foreach { t => + catalog.dropTable("default", t, true, false) + } + spark.sessionState.catalog.reset() + } + + override def afterAll(): Unit = { + catalog = null + } + + test("SPARK-21617: ALTER TABLE for non-compatible DataSource tables") { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 int) USING json", + StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType))), + hiveCompatible = false) + } + + test("SPARK-21617: ALTER TABLE for Hive-compatible DataSource tables") { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 int) USING parquet", + StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType)))) + } + + test("SPARK-21617: ALTER TABLE for Hive tables") { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 int) STORED AS parquet", + StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType)))) + } + + test("SPARK-21617: ALTER TABLE with incompatible schema on Hive-compatible table") { + val exception = intercept[AnalysisException] { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 string) USING parquet", + StructType(Array(StructField("c2", IntegerType)))) + } + assert(exception.getMessage().contains("types incompatible with the existing columns")) + } + + private def testAlterTable( + tableName: String, + createTableStmt: String, + updatedSchema: StructType, + hiveCompatible: Boolean = true): Unit = { + spark.sql(createTableStmt) + val oldTable = spark.sessionState.catalog.externalCatalog.getTable("default", tableName) + catalog.createTable(oldTable, true) + catalog.alterTableSchema("default", tableName, updatedSchema) + + val updatedTable = catalog.getTable("default", tableName) + assert(updatedTable.schema.fieldNames === updatedSchema.fieldNames) + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index f818e2955546..94384185d190 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} @@ -66,4 +67,28 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te } } } + + test("SPARK-20986 Reset table's statistics after PruneFileSourcePartitions rule") { + withTable("tbl") { + spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl") + sql(s"ANALYZE TABLE tbl COMPUTE STATISTICS") + val tableStats = spark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")).stats + assert(tableStats.isDefined && tableStats.get.sizeInBytes > 0, "tableStats is lost") + + val df = sql("SELECT * FROM tbl WHERE p = 1") + val sizes1 = df.queryExecution.analyzed.collect { + case relation: LogicalRelation => relation.catalogTable.get.stats.get.sizeInBytes + } + assert(sizes1.size === 1, 
s"Size wrong for:\n ${df.queryExecution}") + assert(sizes1(0) == tableStats.get.sizeInBytes) + + val relations = df.queryExecution.optimizedPlan.collect { + case relation: LogicalRelation => relation + } + assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}") + val size2 = relations(0).stats.sizeInBytes + assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes) + assert(size2 < tableStats.get.sizeInBytes) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index d535bef4cc78..cc592cf6ca62 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -162,7 +162,12 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { }.head assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch") - assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch") + + // Scanned columns in `HiveTableScanExec` are generated by the `pruneFilterProject` method + // in `SparkPlanner`. This method internally uses `AttributeSet.toSeq`, in which + // the returned output columns are sorted by the names and expression ids. + assert(actualScannedColumns.sorted === expectedScannedColumns.sorted, + "Scanned columns mismatch") val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala similarity index 55% rename from sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala index 157783abc8c2..022cb7177339 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala @@ -15,30 +15,20 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst +package org.apache.spark.sql.hive.execution -import scala.util.control.NonFatal - -import org.apache.spark.sql.{DataFrame, Dataset, QueryTest} -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.metric.SQLMetricsTestUtils import org.apache.spark.sql.hive.test.TestHiveSingleton +class SQLMetricsSuite extends SQLMetricsTestUtils with TestHiveSingleton { + + test("writing data out metrics: hive") { + testMetricsNonDynamicPartition("hive", "t1") + } -abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { - protected def checkSQL(e: Expression, expectedSQL: String): Unit = { - val actualSQL = e.sql - try { - assert(actualSQL === expectedSQL) - } catch { - case cause: Throwable => - fail( - s"""Wrong SQL generated for the following expression: - | - |${e.prettyName} - | - |$cause - """.stripMargin) + test("writing data out metrics dynamic partition: hive") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + testMetricsDynamicPartition("hive", "hive", "t1") } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c944f28d10ef..09c59000b3e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import java.util.Locale +import java.util.{Locale, Set} import com.google.common.io.Files -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType, CatalogUtils} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, CatalogUtils, HiveTableRelation} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} @@ -98,46 +98,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(query1, Row("x1_y1") :: Row("x2_y2") :: Nil) } - test("UDTF") { - withUserDefinedFunction("udtf_count2" -> true) { - sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") - // The function source code can be found at: - // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( - """ - |CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin) - - checkAnswer( - sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), - Row(97, 500) :: Row(97, 500) :: Nil) - - checkAnswer( - sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), - Row(3) :: Row(3) :: Nil) - } - } - - test("permanent UDTF") { - withUserDefinedFunction("udtf_count_temp" -> false) { - sql( - s""" - |CREATE 
FUNCTION udtf_count_temp - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").toURI}' - """.stripMargin) - - checkAnswer( - sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"), - Row(97, 500) :: Row(97, 500) :: Nil) - - checkAnswer( - sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), - Row(3) :: Row(3) :: Nil) - } - } - test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") df.createOrReplaceTempView("table1") @@ -176,53 +136,55 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { orders.toDF.createOrReplaceTempView("orders1") orderUpdates.toDF.createOrReplaceTempView("orderupdates1") - sql( - """CREATE TABLE orders( - | id INT, - | make String, - | type String, - | price INT, - | pdate String, - | customer String, - | city String) - |PARTITIONED BY (state STRING, month INT) - |STORED AS PARQUET - """.stripMargin) + withTable("orders", "orderupdates") { + sql( + """CREATE TABLE orders( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) - sql( - """CREATE TABLE orderupdates( - | id INT, - | make String, - | type String, - | price INT, - | pdate String, - | customer String, - | city String) - |PARTITIONED BY (state STRING, month INT) - |STORED AS PARQUET - """.stripMargin) + sql( + """CREATE TABLE orderupdates( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) - sql("set hive.exec.dynamic.partition.mode=nonstrict") - sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") - sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") + sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") - checkAnswer( - sql( - """ - |select orders.state, orders.month - |from orders - |join ( - | select distinct orders.state,orders.month - | from orders - | join orderupdates - | on orderupdates.id = orders.id) ao - | on ao.state = orders.state and ao.month = orders.month - """.stripMargin), - (1 to 6).map(_ => Row("CA", 20151))) + checkAnswer( + sql( + """ + |select orders.state, orders.month + |from orders + |join ( + | select distinct orders.state,orders.month + | from orders + | join orderupdates + | on orderupdates.id = orders.id) ao + | on ao.state = orders.state and ao.month = orders.month + """.stripMargin), + (1 to 6).map(_ => Row("CA", 20151))) + } } test("show functions") { - val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().toSet[String].toList.sorted + val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().map(_.unquotedString) val allFunctions = sql("SHOW functions").collect().map(r => r(0)) allBuiltinFunctions.foreach { f => assert(allFunctions.contains(f)) @@ -389,21 +351,23 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS with WITH clause") { + val df = Seq((1, 1)).toDF("c1", "c2") df.createOrReplaceTempView("table1") - - sql( - """ - |CREATE TABLE with_table1 AS - |WITH T AS ( - | SELECT * - | FROM table1 - |) - |SELECT * - |FROM T - 
""".stripMargin) - val query = sql("SELECT * FROM with_table1") - checkAnswer(query, Row(1, 1) :: Nil) + withTable("with_table1") { + sql( + """ + |CREATE TABLE with_table1 AS + |WITH T AS ( + | SELECT * + | FROM table1 + |) + |SELECT * + |FROM T + """.stripMargin) + val query = sql("SELECT * FROM with_table1") + checkAnswer(query, Row(1, 1) :: Nil) + } } test("explode nested Field") { @@ -451,10 +415,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val catalogTable = sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) relation match { - case LogicalRelation(r: HadoopFsRelation, _, _) => + case LogicalRelation(r: HadoopFsRelation, _, _, _) => if (!isDataSourceTable) { fail( - s"${classOf[CatalogRelation].getCanonicalName} is expected, but found " + + s"${classOf[HiveTableRelation].getCanonicalName} is expected, but found " + s"${HadoopFsRelation.getClass.getCanonicalName}.") } userSpecifiedLocation match { @@ -464,11 +428,11 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } assert(catalogTable.provider.get === format) - case r: CatalogRelation => + case r: HiveTableRelation => if (isDataSourceTable) { fail( s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " + - s"${classOf[CatalogRelation].getCanonicalName}.") + s"${classOf[HiveTableRelation].getCanonicalName}.") } userSpecifiedLocation match { case Some(location) => @@ -604,86 +568,90 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS with serde") { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql( - """CREATE TABLE ctas2 - | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" - | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") - | STORED AS RCFile - | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") - | AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin) - - val storageCtas2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("ctas2")).storage - assert(storageCtas2.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(storageCtas2.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(storageCtas2.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) - - sql( - """CREATE TABLE ctas3 - | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' - | STORED AS textfile AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin) - - // the table schema may like (key: integer, value: string) - sql( - """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin) - // do nothing cause the table ctas4 already existed. - sql( - """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin) + withTable("ctas1", "ctas2", "ctas3", "ctas4", "ctas5") { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + sql( + """CREATE TABLE ctas2 + | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + | STORED AS RCFile + | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + | AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin) + + val storageCtas2 = spark.sessionState.catalog. 
+ getTableMetadata(TableIdentifier("ctas2")).storage + assert(storageCtas2.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(storageCtas2.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(storageCtas2.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) - checkAnswer( - sql("SELECT k, value FROM ctas1 ORDER BY k, value"), - sql("SELECT key, value FROM src ORDER BY key, value")) - checkAnswer( - sql("SELECT key, value FROM ctas2 ORDER BY key, value"), sql( - """ + """CREATE TABLE ctas3 + | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' + | STORED AS textfile AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin) + + // the table schema may like (key: integer, value: string) + sql( + """CREATE TABLE IF NOT EXISTS ctas4 AS + | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin) + // do nothing cause the table ctas4 already existed. + sql( + """CREATE TABLE IF NOT EXISTS ctas4 AS + | SELECT key, value FROM src ORDER BY key, value""".stripMargin) + + checkAnswer( + sql("SELECT k, value FROM ctas1 ORDER BY k, value"), + sql("SELECT key, value FROM src ORDER BY key, value")) + checkAnswer( + sql("SELECT key, value FROM ctas2 ORDER BY key, value"), + sql( + """ SELECT key, value FROM src ORDER BY key, value""")) - checkAnswer( - sql("SELECT key, value FROM ctas3 ORDER BY key, value"), - sql( - """ + checkAnswer( + sql("SELECT key, value FROM ctas3 ORDER BY key, value"), + sql( + """ SELECT key, value FROM src ORDER BY key, value""")) - intercept[AnalysisException] { - sql( - """CREATE TABLE ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin) - } - checkAnswer( - sql("SELECT key, value FROM ctas4 ORDER BY key, value"), - sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) - - sql( - """CREATE TABLE ctas5 - | STORED AS parquet AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin) - val storageCtas5 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("ctas5")).storage - assert(storageCtas5.inputFormat == - Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) - assert(storageCtas5.outputFormat == - Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - assert(storageCtas5.serde == - Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - - - // use the Hive SerDe for parquet tables - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { + intercept[AnalysisException] { + sql( + """CREATE TABLE ctas4 AS + | SELECT key, value FROM src ORDER BY key, value""".stripMargin) + } checkAnswer( - sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value")) + sql("SELECT key, value FROM ctas4 ORDER BY key, value"), + sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) + + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin) + val storageCtas5 = spark.sessionState.catalog. 
+ getTableMetadata(TableIdentifier("ctas5")).storage + assert(storageCtas5.inputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(storageCtas5.outputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(storageCtas5.serde == + Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + + + // use the Hive SerDe for parquet tables + withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { + checkAnswer( + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value")) + } } } @@ -756,40 +724,46 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("double nested data") { - sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) - .toDF().createOrReplaceTempView("nested") - checkAnswer( - sql("SELECT f1.f2.f3 FROM nested"), - Row(1)) + withTable("test_ctas_1234") { + sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) + .toDF().createOrReplaceTempView("nested") + checkAnswer( + sql("SELECT f1.f2.f3 FROM nested"), + Row(1)) - sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested") - checkAnswer( - sql("SELECT * FROM test_ctas_1234"), - sql("SELECT * FROM nested").collect().toSeq) + sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested") + checkAnswer( + sql("SELECT * FROM test_ctas_1234"), + sql("SELECT * FROM nested").collect().toSeq) - intercept[AnalysisException] { - sql("CREATE TABLE test_ctas_1234 AS SELECT * from notexists").collect() + intercept[AnalysisException] { + sql("CREATE TABLE test_ctas_1234 AS SELECT * from notexists").collect() + } } } test("test CTAS") { - sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src") - checkAnswer( - sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), - sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) + withTable("test_ctas_1234") { + sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src") + checkAnswer( + sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), + sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) + } } test("SPARK-4825 save join to table") { - val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() - sql("CREATE TABLE test1 (key INT, value STRING)") - testData.write.mode(SaveMode.Append).insertInto("test1") - sql("CREATE TABLE test2 (key INT, value STRING)") - testData.write.mode(SaveMode.Append).insertInto("test2") - testData.write.mode(SaveMode.Append).insertInto("test2") - sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") - checkAnswer( - table("test"), - sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq) + withTable("test1", "test2", "test") { + val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() + sql("CREATE TABLE test1 (key INT, value STRING)") + testData.write.mode(SaveMode.Append).insertInto("test1") + sql("CREATE TABLE test2 (key INT, value STRING)") + testData.write.mode(SaveMode.Append).insertInto("test2") + testData.write.mode(SaveMode.Append).insertInto("test2") + sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") + checkAnswer( + table("test"), + sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq) + } } test("SPARK-3708 Backticks aren't handled correctly is aliases") { @@ -948,7 +922,7 @@ class SQLQuerySuite extends QueryTest with 
SQLTestUtils with TestHiveSingleton { withSQLConf(SQLConf.CONVERT_CTAS.key -> "false") { sql("CREATE TABLE explodeTest (key bigInt)") table("explodeTest").queryExecution.analyzed match { - case SubqueryAlias(_, r: CatalogRelation) => // OK + case SubqueryAlias(_, r: HiveTableRelation) => // OK case _ => fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") } @@ -965,14 +939,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("sanity test for SPARK-6618") { - (1 to 100).par.map { i => - val tableName = s"SPARK_6618_table_$i" - sql(s"CREATE TABLE $tableName (col1 string)") - sessionState.catalog.lookupRelation(TableIdentifier(tableName)) - table(tableName) - tables() - sql(s"DROP TABLE $tableName") + val threads: Seq[Thread] = (1 to 10).map { i => + new Thread("test-thread-" + i) { + override def run(): Unit = { + val tableName = s"SPARK_6618_table_$i" + sql(s"CREATE TABLE $tableName (col1 string)") + sessionState.catalog.lookupRelation(TableIdentifier(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") + } + } } + threads.foreach(_.start()) + threads.foreach(_.join(10000)) } test("SPARK-5203 union with different decimal precision") { @@ -1877,14 +1857,16 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-17108: Fix BIGINT and INT comparison failure in spark sql") { - sql("create table t1(a map>)") - sql("select * from t1 where a[1] is not null") + withTable("t1", "t2", "t3") { + sql("create table t1(a map>)") + sql("select * from t1 where a[1] is not null") - sql("create table t2(a map>)") - sql("select * from t2 where a[1] is not null") + sql("create table t2(a map>)") + sql("select * from t2 where a[1] is not null") - sql("create table t3(a map>)") - sql("select * from t3 where a[1L] is not null") + sql("create table t3(a map>)") + sql("select * from t3 where a[1L] is not null") + } } test("SPARK-17796 Support wildcard character in filename for LOAD DATA LOCAL INPATH") { @@ -2015,4 +1997,57 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(table.filter($"p" === "p1\" and q=\"q1").select($"a"), Row(4)) } } + + test("SPARK-21721: Clear FileSystem deleterOnExit cache if path is successfully removed") { + val table = "test21721" + withTable(table) { + val deleteOnExitField = classOf[FileSystem].getDeclaredField("deleteOnExit") + deleteOnExitField.setAccessible(true) + + val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration) + val setOfPath = deleteOnExitField.get(fs).asInstanceOf[Set[Path]] + + val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() + sql(s"CREATE TABLE $table (key INT, value STRING)") + val pathSizeToDeleteOnExit = setOfPath.size() + + (0 to 10).foreach(_ => testData.write.mode(SaveMode.Append).insertInto(table)) + + assert(setOfPath.size() == pathSizeToDeleteOnExit) + } + } + + test("SPARK-21912 ORC/Parquet table should not create invalid column names") { + Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name => + withTable("t21912") { + Seq("ORC", "PARQUET").foreach { source => + val m = intercept[AnalysisException] { + sql(s"CREATE TABLE t21912(`col$name` INT) USING $source") + }.getMessage + assert(m.contains(s"contains invalid character(s)")) + + val m2 = intercept[AnalysisException] { + sql(s"CREATE TABLE t21912 USING $source AS SELECT 1 `col$name`") + }.getMessage + assert(m2.contains(s"contains invalid character(s)")) 
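Note: the SPARK-6618 sanity test above swaps the parallel collection for explicitly named threads that are started and then joined with a timeout. A standalone sketch of that test shape, with `createAndDropTable` standing in for the catalog calls the real test makes (none of these names are Spark APIs):

```scala
import java.util.concurrent.ConcurrentLinkedQueue

// Run the same operation concurrently from N named threads, bound how long
// the test waits for each one, and surface any per-thread failure afterwards.
def createAndDropTable(i: Int): Unit = {
  // placeholder for: CREATE TABLE ... / lookupRelation(...) / DROP TABLE ...
  Thread.sleep(10)
}

val failures = new ConcurrentLinkedQueue[Throwable]()
val threads: Seq[Thread] = (1 to 10).map { i =>
  new Thread(s"test-thread-$i") {
    override def run(): Unit = {
      try createAndDropTable(i)
      catch { case t: Throwable => failures.add(t) }
    }
  }
}
threads.foreach(_.start())
threads.foreach(_.join(10000)) // wait at most 10 seconds per thread
assert(failures.isEmpty, s"concurrent table operations failed: $failures")
```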
+ + withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { + val m3 = intercept[AnalysisException] { + sql(s"CREATE TABLE t21912(`col$name` INT) USING hive OPTIONS (fileFormat '$source')") + }.getMessage + assert(m3.contains(s"contains invalid character(s)")) + } + } + + // TODO: After SPARK-21929, we need to check ORC, too. + Seq("PARQUET").foreach { source => + sql(s"CREATE TABLE t21912(`col` INT) USING $source") + val m = intercept[AnalysisException] { + sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)") + }.getMessage + assert(m.contains(s"contains invalid character(s)")) + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala index a20c758a83e7..3f9485dd018b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -232,31 +232,4 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 46, 99807.08486666666, -0.9978877469246935, -5664.856666666666))) // scalastyle:on } - - test("null arguments") { - checkAnswer(sql(""" - |select p_mfgr, p_name, p_size, - |sum(null) over(distribute by p_mfgr sort by p_name) as sum, - |avg(null) over(distribute by p_mfgr sort by p_name) as avg - |from part - """.stripMargin), - sql(""" - |select p_mfgr, p_name, p_size, - |null as sum, - |null as avg - |from part - """.stripMargin)) - } - - test("SPARK-16646: LAST_VALUE(FALSE) OVER ()") { - checkAnswer(sql("SELECT LAST_VALUE(FALSE) OVER ()"), Row(false)) - checkAnswer(sql("SELECT LAST_VALUE(FALSE, FALSE) OVER ()"), Row(false)) - checkAnswer(sql("SELECT LAST_VALUE(TRUE, TRUE) OVER ()"), Row(true)) - } - - test("SPARK-16646: FIRST_VALUE(FALSE) OVER ()") { - checkAnswer(sql("SELECT FIRST_VALUE(FALSE) OVER ()"), Row(false)) - checkAnswer(sql("SELECT FIRST_VALUE(FALSE, FALSE) OVER ()"), Row(false)) - checkAnswer(sql("SELECT FIRST_VALUE(TRUE, TRUE) OVER ()"), Row(true)) - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index 222c24927a76..de6f0d67f173 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -45,7 +45,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { var maybeRelation: Option[HadoopFsRelation] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) => + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => maybeRelation = Some(orcRelation) filters }.flatten.reduceLeftOption(_ && _) @@ -89,7 +89,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { var maybeRelation: Option[HadoopFsRelation] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) => + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => maybeRelation = Some(orcRelation) filters }.flatten.reduceLeftOption(_ && _) diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 8c855730c31f..60ccd996d6d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator} import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHive._ @@ -475,7 +475,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } else { queryExecution.analyzed.collectFirst { - case _: CatalogRelation => () + case _: HiveTableRelation => () }.getOrElse { fail(s"Expecting no conversion from orc to data sources, " + s"but got:\n$queryExecution") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 6bfb88c0c1af..781de6631f32 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -22,8 +22,8 @@ import java.io.File import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -149,11 +149,11 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - assert(new OrcOptions(Map("Orc.Compress" -> "NONE")).compressionCodec == "NONE") + val conf = sqlContext.sessionState.conf + assert(new OrcOptions(Map("Orc.Compress" -> "NONE"), conf).compressionCodec == "NONE") } test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client val location = Utils.createTempDir() val uri = location.toURI try { @@ -195,6 +195,30 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA Utils.deleteRecursively(location) } } + + test("SPARK-21839: Add SQL config for ORC compression") { + val conf = sqlContext.sessionState.conf + // Test if the default of spark.sql.orc.compression.codec is snappy + assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "SNAPPY") + + // OrcOptions's parameters have a higher priority than SQL configuration. 
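Note: the SPARK-21839 test above checks a fixed precedence for the ORC codec: the data source option `compression` wins over `orc.compress`, which wins over the `spark.sql.orc.compression.codec` session config, and `uncompressed` is normalized to `NONE`. A small, self-contained sketch of that resolution logic (`resolveOrcCodec` is an illustrative helper, not the real `OrcOptions`):

```scala
// Resolve the effective ORC codec from data source options plus a session-level
// default, with the same precedence and normalization the test exercises.
def resolveOrcCodec(options: Map[String, String], sessionDefault: String): String = {
  // option keys are treated case-insensitively, like a CaseInsensitiveMap
  val lower = options.map { case (k, v) => k.toLowerCase -> v }
  val picked = lower.get("compression")         // highest priority
    .orElse(lower.get("orc.compress"))          // then the ORC-native option
    .getOrElse(sessionDefault)                  // finally the SQL config default
  if (picked.equalsIgnoreCase("uncompressed")) "NONE" else picked.toUpperCase
}

// resolveOrcCodec(Map("orc.compress" -> "zlib", "compression" -> "lzo"), "snappy") == "LZO"
// resolveOrcCodec(Map("Orc.Compress" -> "NONE"), "snappy")                         == "NONE"
// resolveOrcCodec(Map.empty, "uncompressed")                                       == "NONE"
```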
+ // `compression` -> `orc.compression` -> `spark.sql.orc.compression.codec` + withSQLConf(SQLConf.ORC_COMPRESSION.key -> "uncompressed") { + assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "NONE") + val map1 = Map("orc.compress" -> "zlib") + val map2 = Map("orc.compress" -> "zlib", "compression" -> "lzo") + assert(new OrcOptions(map1, conf).compressionCodec == "ZLIB") + assert(new OrcOptions(map2, conf).compressionCodec == "LZO") + } + + // Test all the valid options of spark.sql.orc.compression.codec + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "ZLIB", "LZO").foreach { c => + withSQLConf(SQLConf.ORC_COMPRESSION.key -> c) { + val expected = if (c == "UNCOMPRESSED") "NONE" else c + assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == expected) + } + } + } } class OrcSourceSuite extends OrcSuite { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 23f21e6b9931..740e0837350c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -21,7 +21,7 @@ import java.io.File import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ @@ -285,7 +285,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: HadoopFsRelation, _, _) => // OK + case LogicalRelation(_: HadoopFsRelation, _, _, _) => // OK case _ => fail( "test_parquet_ctas should be converted to " + s"${classOf[HadoopFsRelation ].getCanonicalName }") @@ -370,7 +370,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { assertResult(2) { analyzed.collect { - case r @ LogicalRelation(_: HadoopFsRelation, _, _) => r + case r @ LogicalRelation(_: HadoopFsRelation, _, _, _) => r }.size } } @@ -379,7 +379,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { def collectHadoopFsRelation(df: DataFrame): HadoopFsRelation = { val plan = df.queryExecution.analyzed plan.collectFirst { - case LogicalRelation(r: HadoopFsRelation, _, _) => r + case LogicalRelation(r: HadoopFsRelation, _, _, _) => r }.getOrElse { fail(s"Expecting a HadoopFsRelation 2, but got:\n$plan") } @@ -459,7 +459,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { // Converted test_parquet should be cached. getCachedDataSourceTable(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case LogicalRelation(_: HadoopFsRelation, _, _) => // OK + case LogicalRelation(_: HadoopFsRelation, _, _, _) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. 
" + @@ -812,7 +812,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { } } else { queryExecution.analyzed.collectFirst { - case _: CatalogRelation => + case _: HiveTableRelation => }.getOrElse { fail(s"Expecting no conversion from parquet to data sources, " + s"but got:\n$queryExecution") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index d23b66a5300e..80aff446bc24 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -22,9 +22,6 @@ import java.io.File import scala.util.Random import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter -import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ @@ -783,52 +780,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - test("SPARK-8578 specified custom output committer will not be used to append data") { - withSQLConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key -> - classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { - val extraOptions = Map[String, String]( - SQLConf.OUTPUT_COMMITTER_CLASS.key -> classOf[AlwaysFailOutputCommitter].getName, - // Since Parquet has its own output committer setting, also set it - // to AlwaysFailParquetOutputCommitter at here. - "spark.sql.parquet.output.committer.class" -> - classOf[AlwaysFailParquetOutputCommitter].getName - ) - - val df = spark.range(1, 10).toDF("i") - withTempPath { dir => - df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) - // Because there data already exists, - // this append should succeed because we will use the output committer associated - // with file format and AlwaysFailOutputCommitter will not be used. - df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) - checkAnswer( - spark.read - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .options(extraOptions) - .load(dir.getCanonicalPath), - df.union(df)) - - // This will fail because AlwaysFailOutputCommitter is used when we do append. - intercept[Exception] { - df.write.mode("overwrite") - .options(extraOptions).format(dataSourceName).save(dir.getCanonicalPath) - } - } - withTempPath { dir => - // Because there is no existing data, - // this append will fail because AlwaysFailOutputCommitter is used when we do append - // and there is no existing data. - intercept[Exception] { - df.write.mode("append") - .options(extraOptions) - .format(dataSourceName) - .save(dir.getCanonicalPath) - } - } - } - } - test("SPARK-8887: Explicitly define which data types can be used as dynamic partition columns") { val df = Seq( (1, "v1", Array(1, 2, 3), Map("k1" -> "v1"), Tuple2(1, "4")), @@ -898,27 +849,3 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } } - -// This class is used to test SPARK-8578. We should not use any custom output committer when -// we actually append data to an existing dir. 
-class AlwaysFailOutputCommitter( - outputPath: Path, - context: TaskAttemptContext) - extends FileOutputCommitter(outputPath, context) { - - override def commitJob(context: JobContext): Unit = { - sys.error("Intentional job commitment failure for testing purpose.") - } -} - -// This class is used to test SPARK-8578. We should not use any custom output committer when -// we actually append data to an existing dir. -class AlwaysFailParquetOutputCommitter( - outputPath: Path, - context: TaskAttemptContext) - extends ParquetOutputCommitter(outputPath, context) { - - override def commitJob(context: JobContext): Unit = { - sys.error("Intentional job commitment failure for testing purpose.") - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 9f4009bfe402..60a4638f610b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -103,7 +103,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { // `Cast`ed values are always of internal types (e.g. UTF8String instead of String) Cast(Literal(value), dataType).eval() }) - }.filter(predicate).map(projection) + }.filter(predicate.eval).map(projection) // Appends partition values val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes diff --git a/dev/change-version-to-2.11.sh b/sql/mkdocs.yml old mode 100755 new mode 100644 similarity index 59% rename from dev/change-version-to-2.11.sh rename to sql/mkdocs.yml index 4ccfeef09fd0..c34c891bb9e4 --- a/dev/change-version-to-2.11.sh +++ b/sql/mkdocs.yml @@ -1,23 +1,19 @@ -#!/usr/bin/env bash - -# # Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with +# contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# - -# This script exists for backwards compability. Use change-scala-version.sh instead. -echo "This script is deprecated. 
Please instead run: change-scala-version.sh 2.11" -$(dirname $0)/change-scala-version.sh 2.11 +site_name: Spark SQL, Built-in Functions +theme: readthedocs +pages: + - 'Functions': 'index.md' diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 2803cad8095d..00c59728748f 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -56,7 +56,7 @@ public abstract class WriteAheadLog { public abstract void clean(long threshTime, boolean waitForCompletion); /** - * Close this log and release any resources. + * Close this log and release any resources. It must be idempotent. */ public abstract void close(); } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 5cbad8bf3ce6..b8c780db07c9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -55,6 +55,9 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.master", "spark.yarn.keytab", "spark.yarn.principal", + "spark.yarn.credentials.file", + "spark.yarn.credentials.renewalTime", + "spark.yarn.credentials.updateTime", "spark.ui.filters") val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index a34f6c73fea8..027403816f53 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -458,7 +458,7 @@ class StreamingContext private[streaming] ( queue: Queue[RDD[T]], oneAtATime: Boolean = true ): InputDStream[T] = { - queueStream(queue, oneAtATime, sc.makeRDD(Seq[T](), 1)) + queueStream(queue, oneAtATime, sc.makeRDD(Seq.empty[T], 1)) } /** @@ -596,7 +596,7 @@ class StreamingContext private[streaming] ( } logDebug("Adding shutdown hook") // force eager creation of logger shutdownHookRef = ShutdownHookManager.addShutdownHook( - StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) + StreamingContext.SHUTDOWN_HOOK_PRIORITY)(() => stopOnShutdown()) // Registering Streaming Metrics at the start of the StreamingContext assert(env.metricsSystem != null) env.metricsSystem.registerSource(streamingSource) @@ -659,8 +659,7 @@ class StreamingContext private[streaming] ( def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { var shutdownHookRefToRemove: AnyRef = null if (LiveListenerBus.withinListenerThread.value) { - throw new SparkException( - s"Cannot stop StreamingContext within listener thread of ${LiveListenerBus.name}") + throw new SparkException(s"Cannot stop StreamingContext within listener bus thread.") } synchronized { // The state should always be Stopped after calling `stop()`, even if we haven't started yet diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index a0a40fcee26d..4a0ec31b5f3c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -153,7 +153,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def context(): StreamingContext = dstream.context /** Return a new DStream by applying a function to all elements of this DStream. */ - def map[R](f: JFunction[T, R]): JavaDStream[R] = { + def map[U](f: JFunction[T, U]): JavaDStream[U] = { new JavaDStream(dstream.map(f)(fakeClassTag))(fakeClassTag) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 905b1c52afa6..b8a5a96faf15 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -164,6 +164,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( /** Clear the old time-to-files mappings along with old RDDs */ protected[streaming] override def clearMetadata(time: Time) { + super.clearMetadata(time) batchTimeToSelectedFiles.synchronized { val oldFiles = batchTimeToSelectedFiles.filter(_._1 < (time - rememberDuration)) batchTimeToSelectedFiles --= oldFiles.keys diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index f38c1e799659..dcb51d72fa58 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -389,6 +389,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. + * In every batch the updateFunc will be called for each state even if there are no new values. * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. @@ -403,6 +404,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. + * In every batch the updateFunc will be called for each state even if there are no new values. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. @@ -419,6 +421,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of the key. + * In every batch the updateFunc will be called for each state even if there are no new values. * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. 
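Note: the FileInputDStream change above adds the missing `super.clearMetadata(time)` call so the parent DStream's per-batch bookkeeping is purged along with the file map. A generic sketch of that override-plus-super pattern (the class and field names here are illustrative only):

```scala
import scala.collection.mutable

// A subclass that tracks extra per-batch metadata must clear its own map
// and also delegate to the parent, otherwise the parent's cache grows forever.
class BaseStream {
  protected val generatedBatches = mutable.Map[Long, String]()
  def clearMetadata(time: Long): Unit = {
    generatedBatches.retain { case (batchTime, _) => batchTime >= time }
  }
}

class FileBackedStream extends BaseStream {
  private val batchToFiles = mutable.Map[Long, Seq[String]]()
  override def clearMetadata(time: Long): Unit = {
    super.clearMetadata(time)   // the step the fix restores
    batchToFiles.retain { case (batchTime, _) => batchTime >= time }
  }
}
```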
@@ -440,6 +443,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. + * In every batch the updateFunc will be called for each state even if there are no new values. * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. * @param updateFunc State update function. Note, that this function may generate a different * tuple with a different key than the input key. Therefore keys may be removed @@ -464,6 +468,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of the key. + * In every batch the updateFunc will be called for each state even if there are no new values. * org.apache.spark.Partitioner is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. @@ -487,6 +492,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. + * In every batch the updateFunc will be called for each state even if there are no new values. * org.apache.spark.Partitioner is used to control the partitioning of each RDD. * @param updateFunc State update function. Note, that this function may generate a different * tuple with a different key than the input key. Therefore keys may be removed @@ -513,6 +519,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of the key. + * In every batch the updateFunc will be called for each state even if there are no new values. * org.apache.spark.Partitioner is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. 
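Note: the doc additions above make the contract explicit: on every batch, the update function runs for every key that has state, and keys with no new records see an empty `Seq`, not a missing call. A minimal word count sketch built around that contract (the socket source, port, and checkpoint path are placeholders):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf().setMaster("local[2]").setAppName("updateStateByKey-sketch")
val ssc = new StreamingContext(conf, Seconds(1))
ssc.checkpoint("/tmp/streaming-checkpoint")   // stateful ops require a checkpoint dir

// newValues is empty for keys with no data this batch, so the function must
// still decide what to do with the carried-over state; returning None drops the key.
val updateFunc = (newValues: Seq[Int], state: Option[Long]) => {
  val updated = state.getOrElse(0L) + newValues.sum
  if (updated == 0L) None else Some(updated)
}

val counts = ssc.socketTextStream("localhost", 9999)
  .flatMap(_.split("\\s+"))
  .map(word => (word, 1))
  .updateStateByKey[Long](updateFunc)

counts.print()
// ssc.start(); ssc.awaitTermination()
```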
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 5bf1dabf08f4..d1a5e9179370 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -76,7 +76,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // Re-apply the update function to the old state RDD val updateFuncLocal = updateFunc val finalFunc = (iterator: Iterator[(K, S)]) => { - val i = iterator.map(t => (t._1, Seq[V](), Option(t._2))) + val i = iterator.map(t => (t._1, Seq.empty[V], Option(t._2))) updateFuncLocal(validTime, i) } val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index d91a64df321a..31a88730d163 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -19,8 +19,8 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.StorageLevel diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index f5c8a88f42af..27644a645727 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -18,8 +18,8 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index bd7ab0b9bf5e..6f130c803f31 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -165,11 +165,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Stop the receiver execution thread. */ def stop(graceful: Boolean): Unit = synchronized { - if (isTrackerStarted) { - // First, stop the receivers - trackerState = Stopping + val isStarted: Boolean = isTrackerStarted + trackerState = Stopping + if (isStarted) { if (!skipReceiverLaunch) { - // Send the stop signal to all the receivers + // First, stop the receivers. 
Send the stop signal to all the receivers endpoint.askSync[Boolean](StopAllReceivers) // Wait for the Spark job that runs the receivers to be over @@ -194,17 +194,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Finally, stop the endpoint ssc.env.rpcEnv.stop(endpoint) endpoint = null - receivedBlockTracker.stop() - logInfo("ReceiverTracker stopped") - trackerState = Stopped - } else if (isTrackerInitialized) { - trackerState = Stopping - // `ReceivedBlockTracker` is open when this instance is created. We should - // close this even if this `ReceiverTracker` is not started. - receivedBlockTracker.stop() - logInfo("ReceiverTracker stopped") - trackerState = Stopped } + + // `ReceivedBlockTracker` is open when this instance is created. We should + // close this even if this `ReceiverTracker` is not started. + receivedBlockTracker.stop() + logInfo("ReceiverTracker stopped") + trackerState = Stopped } /** Allocate all unallocated blocks to the given batch. */ @@ -453,9 +449,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false endpoint.send(StartAllReceivers(receivers)) } - /** Check if tracker has been marked for initiated */ - private def isTrackerInitialized: Boolean = trackerState == Initialized - /** Check if tracker has been marked for starting */ private def isTrackerStarted: Boolean = trackerState == Started diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 5fb0bd057d0f..6a70bf7406b3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -76,7 +76,7 @@ private[streaming] class StreamingListenerBus(sparkListenerBus: LiveListenerBus) * forward them to StreamingListeners. 
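Note: the ReceiverTracker change above moves the `ReceivedBlockTracker` shutdown out of the started-only branch, so a resource acquired at construction time is released even when `start()` was never called. A standalone sketch of that stop() shape (all names here are illustrative, not the real tracker):

```scala
// stop() tears down start()-dependent machinery conditionally, then always
// releases resources that were opened in the constructor.
object TrackerState extends Enumeration {
  val Initialized, Started, Stopping, Stopped = Value
}

class TrackerSketch {
  import TrackerState._

  @volatile private var state: TrackerState.Value = Initialized
  private val blockTracker = new AutoCloseable {   // opened eagerly at construction
    override def close(): Unit = println("block tracker closed")
  }

  def start(): Unit = { state = Started }

  def stop(): Unit = synchronized {
    val wasStarted = state == Started
    state = Stopping
    if (wasStarted) {
      println("signalling receivers and stopping the RPC endpoint")
    }
    blockTracker.close()   // unconditional: runs whether or not start() ever happened
    state = Stopped
  }
}
```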
*/ def start(): Unit = { - sparkListenerBus.addListener(this) // for getting callbacks on spark events + sparkListenerBus.addToStatusQueue(this) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 70b4bb466c46..f1070e9029cb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -25,7 +25,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) protected def columns: Seq[Node] = { Batch Time - Input Size + Records Scheduling Delay {SparkUIUtils.tooltip("Time taken by Streaming scheduler to submit jobs of a batch", "top")} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index f55af6a5cc35..69e15655ad79 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -304,7 +304,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized { - val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse { + // stripXSS is called first to remove suspicious characters used in XSS attacks + val batchTime = + Option(SparkUIUtils.stripXSS(request.getParameter("id"))).map(id => Time(id.toLong)) + .getOrElse { throw new IllegalArgumentException(s"Missing id parameter") } val formattedBatchTime = diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 35f0166ed0cf..71b86d16866e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import java.util.{Iterator => JIterator} import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -60,7 +61,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp private val walWriteQueue = new LinkedBlockingQueue[Record]() // Whether the writer thread is active - @volatile private var active: Boolean = true + private val active: AtomicBoolean = new AtomicBoolean(true) private val buffer = new ArrayBuffer[Record]() private val batchedWriterThread = startBatchedWriterThread() @@ -72,7 +73,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { val promise = Promise[WriteAheadLogRecordHandle]() val putSuccessfully = synchronized { - if (active) { + if (active.get()) { walWriteQueue.offer(Record(byteBuffer, time, promise)) true } else { @@ -121,9 +122,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp */ override def close(): Unit = { logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.") - synchronized { - active = false - } + if (!active.getAndSet(false)) return 
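Note: the BatchedWriteAheadLog change above replaces a `@volatile var` guarded by `synchronized` with a single `AtomicBoolean`, making `close()` naturally idempotent: `getAndSet(false)` lets exactly one caller run the shutdown body. A compact sketch of the idiom (class and method names are illustrative):

```scala
import java.util.concurrent.atomic.AtomicBoolean

class ClosableWorker {
  private val active = new AtomicBoolean(true)

  def submit(work: String): Boolean =
    active.get() && { println(s"queued: $work"); true }

  def close(): Unit = {
    if (!active.getAndSet(false)) return   // already closed: later calls are no-ops
    println("draining the queue and joining the writer thread")
  }
}

val w = new ClosableWorker
w.submit("a")
w.close()
w.close()   // safe: the shutdown body runs only once
```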
batchedWriterThread.interrupt() batchedWriterThread.join() while (!walWriteQueue.isEmpty) { @@ -138,7 +137,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp private def startBatchedWriterThread(): Thread = { val thread = new Thread(new Runnable { override def run(): Unit = { - while (active) { + while (active.get()) { try { flushRecords() } catch { @@ -166,7 +165,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp } try { var segment: WriteAheadLogRecordHandle = null - if (buffer.length > 0) { + if (buffer.nonEmpty) { logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") // threads may not be able to add items in order by time val sortedByTime = buffer.sortBy(_.time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 845f554308c4..d6e15cfdd272 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -189,7 +189,9 @@ private[streaming] class FileBasedWriteAheadLog( val f = Future { deleteFile(logInfo) }(executionContext) if (waitForCompletion) { import scala.concurrent.duration._ + // scalastyle:off awaitready Await.ready(f, 1 second) + // scalastyle:on awaitready } } catch { case e: RejectedExecutionException => @@ -203,10 +205,12 @@ private[streaming] class FileBasedWriteAheadLog( /** Stop the manager, close any open log writer */ def close(): Unit = synchronized { - if (currentLogWriter != null) { - currentLogWriter.close() + if (!executionContext.isShutdown) { + if (currentLogWriter != null) { + currentLogWriter.close() + } + executionContext.shutdown() } - executionContext.shutdown() logInfo("Stopped write ahead log manager") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 408936653c79..eb9996ece377 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -63,7 +63,6 @@ object RawTextHelper { var i = 0 var len = 0 - var done = false var value: (String, Long) = null var swap: (String, Long) = null var count = 0 diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala index ae44fd07ac55..0c4a64ccc513 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.api.java.JavaRDDLike -import org.apache.spark.streaming.api.java.{JavaDStreamLike, JavaDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaDStreamLike, JavaStreamingContext} /** Exposes streaming test functionality in a Java-friendly way. 
*/ trait JavaTestBase extends TestSuiteBase { @@ -35,7 +35,7 @@ trait JavaTestBase extends TestSuiteBase { def attachTestInputStream[T]( ssc: JavaStreamingContext, data: JList[JList[T]], - numPartitions: Int) = { + numPartitions: Int): JavaDStream[T] = { val seqData = data.asScala.map(_.asScala) implicit val cm: ClassTag[T] = @@ -47,9 +47,9 @@ trait JavaTestBase extends TestSuiteBase { /** * Attach a provided stream to it's associated StreamingContext as a * [[org.apache.spark.streaming.TestOutputStream]]. - **/ + */ def attachTestOutputStream[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]]( - dstream: JavaDStreamLike[T, This, R]) = { + dstream: JavaDStreamLike[T, This, R]): Unit = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val ostream = new TestOutputStreamWithPartitions(dstream.dstream) @@ -90,10 +90,10 @@ trait JavaTestBase extends TestSuiteBase { } object JavaTestUtils extends JavaTestBase { - override def maxWaitTimeMillis = 20000 + override def maxWaitTimeMillis: Int = 20000 } object JavaCheckpointTestUtils extends JavaTestBase { - override def actuallyWait = true + override def actuallyWait: Boolean = true } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index a3062ac94614..6f62c7a88dc3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -25,12 +25,11 @@ import scala.reflect.ClassTag import org.scalatest.concurrent.Eventually.eventually -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{HashPartitioner, SparkConf, SparkException} import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} import org.apache.spark.util.{Clock, ManualClock} -import org.apache.spark.HashPartitioner class BasicOperationsSuite extends TestSuiteBase { test("map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 3c4a2716caf9..fe65353b9d50 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -50,7 +50,6 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) extends SparkFunSuite with BeforeAndAfter with Matchers - with LocalSparkContext with Logging { import WriteAheadLogBasedBlockHandler._ @@ -89,10 +88,9 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) - sc = new SparkContext("local", "test", conf) blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(sc))), conf, true) + new LiveListenerBus(conf))), conf, true) storageLevel = StorageLevel.MEMORY_ONLY_SER blockManager = createBlockManager(blockManagerSize, conf) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 1b1e21f6e5ba..5fc626c1f78b 100644 --- 
a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -24,8 +24,8 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf @@ -36,7 +36,7 @@ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ import org.apache.spark.util.Utils /** Testsuite for testing the network receiver behavior */ -class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { +class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { test("receiver life cycle") { @@ -60,6 +60,8 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { // Verify that the receiver intercept[Exception] { + // Necessary to make failAfter interrupt awaitTermination() in ScalaTest 3.x + implicit val signaler: Signaler = ThreadSignaler failAfter(200 millis) { executingThread.join() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index eb996c93ff38..5810e73f4098 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -27,8 +27,8 @@ import scala.collection.mutable.Queue import org.apache.commons.io.FileUtils import org.scalatest.{Assertions, BeforeAndAfter, PrivateMethodTester} +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ @@ -42,7 +42,7 @@ import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging { +class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeLimits with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName @@ -406,6 +406,8 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo // test whether awaitTermination() does not exit if not time is given val exception = intercept[Exception] { + // Necessary to make failAfter interrupt awaitTermination() in ScalaTest 3.x + implicit val signaler: Signaler = ThreadSignaler failAfter(1000 millis) { ssc.awaitTermination() throw new Exception("Did not wait for stop") @@ -573,8 +575,6 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo test("getActive and getActiveOrCreate") { require(StreamingContext.getActive().isEmpty, "context exists from before") - sc = new SparkContext(conf) - var newContextCreated = false def creatingFunc(): StreamingContext = { @@ -601,6 +601,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo // getActiveOrCreate should create new context and getActive should return it only // after starting the context testGetActiveOrCreate { + sc = new SparkContext(conf) ssc = StreamingContext.getActiveOrCreate(creatingFunc _) assert(ssc != null, "no context created") assert(newContextCreated === 
true, "new context not created") @@ -620,6 +621,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo // getActiveOrCreate and getActive should return independently created context after activating testGetActiveOrCreate { + sc = new SparkContext(conf) ssc = creatingFunc() // Create assert(StreamingContext.getActive().isEmpty, "new initialized context returned before starting") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index e7cec999c219..f2204a187093 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -121,11 +121,11 @@ class UISeleniumSuite h4Text.exists(_.matches("Completed Batches \\(last \\d+ out of \\d+\\)")) should be (true) findAll(cssSelector("""#active-batches-table th""")).map(_.text).toSeq should be { - List("Batch Time", "Input Size", "Scheduling Delay (?)", "Processing Time (?)", + List("Batch Time", "Records", "Scheduling Delay (?)", "Processing Time (?)", "Output Ops: Succeeded/Total", "Status") } findAll(cssSelector("""#completed-batches-table th""")).map(_.text).toSeq should be { - List("Batch Time", "Input Size", "Scheduling Delay (?)", "Processing Time (?)", + List("Batch Time", "Records", "Scheduling Delay (?)", "Processing Time (?)", "Total Delay (?)", "Output Ops: Succeeded/Total") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index b70383ecde4d..898da4445e46 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -21,12 +21,12 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.language.reflectiveCalls import org.scalatest.BeforeAndAfter import org.scalatest.Matchers._ +import org.scalatest.concurrent.{Signaler, ThreadSignaler} import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.Timeouts._ +import org.scalatest.concurrent.TimeLimits._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} @@ -35,6 +35,7 @@ import org.apache.spark.util.ManualClock class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { + implicit val defaultSignaler: Signaler = ThreadSignaler private val blockIntervalMs = 10 private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") @volatile private var blockGenerator: BlockGenerator = null @@ -202,21 +203,17 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { test("block push errors are reported") { val listener = new TestBlockGeneratorListener { - @volatile var errorReported = false override def onPushBlock( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { throw new SparkException("test") } - override def onError(message: String, throwable: Throwable): Unit = { - errorReported = true - } } blockGenerator = new BlockGenerator(listener, 0, conf) blockGenerator.start() - assert(listener.errorReported === false) + assert(listener.onErrorCalled === false) blockGenerator.addData(1) eventually(timeout(1 second), interval(10 milliseconds)) { - 
assert(listener.errorReported === true) + assert(listener.onErrorCalled === true) } blockGenerator.stop() } @@ -243,12 +240,15 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { @volatile var onGenerateBlockCalled = false @volatile var onAddDataCalled = false @volatile var onPushBlockCalled = false + @volatile var onErrorCalled = false override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { pushedData.addAll(arrayBuffer.asJava) onPushBlockCalled = true } - override def onError(message: String, throwable: Throwable): Unit = {} + override def onError(message: String, throwable: Throwable): Unit = { + onErrorCalled = true + } override def onGenerateBlock(blockId: StreamBlockId): Unit = { onGenerateBlockCalled = true } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala index 1d2bf35a6d45..8d81b582e4d3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala @@ -21,7 +21,7 @@ import org.mockito.Matchers.{eq => meq} import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, PrivateMethodTester} import org.scalatest.concurrent.Eventually.{eventually, timeout} -import org.scalatest.mock.MockitoSugar +import org.scalatest.mockito.MockitoSugar import org.scalatest.time.SpanSugar._ import org.apache.spark.{ExecutorAllocationClient, SparkConf, SparkFunSuite} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index df122ac090c3..c206d3169d77 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -57,6 +57,8 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } finally { tracker.stop(false) + // Make sure it is idempotent. + tracker.stop(false) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 4bec52b9fe4f..4a2549fc0a96 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -36,7 +36,7 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach, PrivateMethodTester} import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.mock.MockitoSugar +import org.scalatest.mockito.MockitoSugar import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.streaming.scheduler._ @@ -140,6 +140,8 @@ abstract class CommonWriteAheadLogTests( } } writeAheadLog.close() + // Make sure it is idempotent. 
+ writeAheadLog.close() } test(testPrefix + "handling file errors while reading rotating logs") { @@ -482,7 +484,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // we make the write requests in separate threads so that we don't block the test thread private def writeAsync(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { val p = Promise[Unit]() - p.completeWith(Future { + p.completeWith(Future[Unit] { val v = wal.write(event, time) assert(v === walHandle) }(walBatchingExecutionContext)) diff --git a/tools/pom.xml b/tools/pom.xml index 7ba4dc9842f1..37427e8da62d 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -44,7 +44,7 @@ <groupId>org.clapper</groupId> <artifactId>classutil_${scala.binary.version}</artifactId> - <version>1.0.6</version> + <version>1.1.2</version>
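The recurring change across these streaming test hunks is the ScalaTest 3.x migration: `Timeouts` is replaced by `TimeLimits`, `org.scalatest.mock.MockitoSugar` moves to `org.scalatest.mockito.MockitoSugar`, and suites that wrap blocking calls in `failAfter` now declare an implicit `Signaler`. The sketch below is not part of the patch; it is a minimal, self-contained illustration of the `ThreadSignaler` pattern, assuming ScalaTest 3.0.x on the classpath (the suite name and sleep duration are made up for the example).

```scala
// Illustrative only (not from the patch): why the streaming suites declare an
// implicit ThreadSignaler. In ScalaTest 3.x the default signaler is DoNotSignal,
// so failAfter would time the test out but could not break out of a call that
// blocks indefinitely, such as awaitTermination().
import org.scalatest.FunSuite
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.SpanSugar._

class TimeLimitSketchSuite extends FunSuite with TimeLimits {

  // ThreadSignaler interrupts the thread running the time-limited block
  // once the limit is exceeded, so the blocking call actually returns.
  implicit val signaler: Signaler = ThreadSignaler

  test("failAfter interrupts a blocking call") {
    intercept[Exception] {
      failAfter(200.millis) {
        Thread.sleep(60 * 1000) // stands in for ssc.awaitTermination()
      }
    }
  }
}
```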